Cours 10 : Régression logistique¶

Loïc Grobol lgrobol@parisnanterre.fr

2021-10-27

In [1]:
from IPython.display import display, Markdown
In [2]:
import numpy as np

Vectorisations arbitraires de documents¶

On a vu des façons de traiter des documents vus comme des sacs des mots en les représentant comme des vecteurs dont les coordonnées correspondaient à des nombres d'occurrences.

Mais on aimerait — entre autres — pouvoir travailler avec des représentations arbitraires, on peut par exemple imaginer vouloir représenter un document par ŀa polarité (au sens de l'analyse du sentiment) de ses mots.

🧠 Exo 🧠¶

1. Vectoriser un document¶

À l'aide d'un lexique de sentiment (par exemple VADER), écrivez une fonction qui prend en entrée un texte en anglais et renvoie sa représentation sous forme d'un vecteur de features à deux traits : polarité positive moyenne (la somme des polarités positives des mots qu'il contient divisée par sa longueur en nombre de mots) et polarité négative moyenne.

In [3]:
def read_vader(vader_path):
    pass  # À vous de jouer
In [4]:
def featurize(doc, lexicon):
    pass # À vous de jouer !
In [5]:
lexicon = read_vader("../../data/vader_lexicon.txt")
doc = "I came in in the middle of this film so I had no idea about any credits or even its title till I looked it up here, where I see that it has received a mixed reception by your commentators. I'm on the positive side regarding this film but one thing really caught my attention as I watched: the beautiful and sensitive score written in a Coplandesque Americana style. My surprise was great when I discovered the score to have been written by none other than John Williams himself. True he has written sensitive and poignant scores such as Schindler's List but one usually associates his name with such bombasticities as Star Wars. But in my opinion what Williams has written for this movie surpasses anything I've ever heard of his for tenderness, sensitivity and beauty, fully in keeping with the tender and lovely plot of the movie. And another recent score of his, for Catch Me if You Can, shows still more wit and sophistication. As to Stanley and Iris, I like education movies like How Green was my Valley and Konrack, that one with John Voigt and his young African American charges in South Carolina, and Danny deVito's Renaissance Man, etc. They tell a necessary story of intellectual and spiritual awakening, a story which can't be told often enough. This one is an excellent addition to that genre."
doc_features = featurize(doc, lexicon)
doc_features

🧠 Correction 1 🧠¶

On commence par recycler notre tokenizer/normaliseur

In [6]:
import re

def poor_mans_tokenizer_and_normalizer(s):
    return [w.lower() for w in re.split(r"\s|\W", s.strip()) if w and all(c.isalpha() for c in w)]

On lit le lexique

In [7]:
def read_vader(vader_path):
    res = dict()
    with open(vader_path) as in_stream:
        for row in in_stream:
            word, polarity, *_ = row.lstrip().split("\t", maxsplit=2)
            res[word] = float(polarity)
    return res
lexicon = read_vader("../../data/vader_lexicon.txt")
lexicon
Out[7]:
{'$:': -1.5,
 '%)': -0.4,
 '%-)': -1.5,
 '&-:': -0.4,
 '&:': -0.7,
 "( '}{' )": 1.6,
 '(%': -0.9,
 "('-:": 2.2,
 "(':": 2.3,
 '((-:': 2.1,
 '(*': 1.1,
 '(-%': -0.7,
 '(-*': 1.3,
 '(-:': 1.6,
 '(-:0': 2.8,
 '(-:<': -0.4,
 '(-:o': 1.5,
 '(-:O': 1.5,
 '(-:{': -0.1,
 '(-:|>*': 1.9,
 '(-;': 1.3,
 '(-;|': 2.1,
 '(8': 2.6,
 '(:': 2.2,
 '(:0': 2.4,
 '(:<': -0.2,
 '(:o': 2.5,
 '(:O': 2.5,
 '(;': 1.1,
 '(;<': 0.3,
 '(=': 2.2,
 '(?:': 2.1,
 '(^:': 1.5,
 '(^;': 1.5,
 '(^;0': 2.0,
 '(^;o': 1.9,
 '(o:': 1.6,
 ")':": -2.0,
 ")-':": -2.1,
 ')-:': -2.1,
 ')-:<': -2.2,
 ')-:{': -2.1,
 '):': -1.8,
 '):<': -1.9,
 '):{': -2.3,
 ');<': -2.6,
 '*)': 0.6,
 '*-)': 0.3,
 '*-:': 2.1,
 '*-;': 2.4,
 '*:': 1.9,
 '*<|:-)': 1.6,
 '*\\0/*': 2.3,
 '*^:': 1.6,
 ',-:': 1.2,
 "---'-;-{@": 2.3,
 '--<--<@': 2.2,
 '.-:': -1.2,
 '..###-:': -1.7,
 '..###:': -1.9,
 '/-:': -1.3,
 '/:': -1.3,
 '/:<': -1.4,
 '/=': -0.9,
 '/^:': -1.0,
 '/o:': -1.4,
 '0-8': 0.1,
 '0-|': -1.2,
 '0:)': 1.9,
 '0:-)': 1.4,
 '0:-3': 1.5,
 '0:03': 1.9,
 '0;^)': 1.6,
 '0_o': -0.3,
 '10q': 2.1,
 '1337': 2.1,
 '143': 3.2,
 '1432': 2.6,
 '14aa41': 2.4,
 '182': -2.9,
 '187': -3.1,
 '2g2b4g': 2.8,
 '2g2bt': -0.1,
 '2qt': 2.1,
 '3:(': -2.2,
 '3:)': 0.5,
 '3:-(': -2.3,
 '3:-)': -1.4,
 '4col': -2.2,
 '4q': -3.1,
 '5fs': 1.5,
 '8)': 1.9,
 '8-d': 1.7,
 '8-o': -0.3,
 '86': -1.6,
 '8d': 2.9,
 ':###..': -2.4,
 ':$': -0.2,
 ':&': -0.6,
 ":'(": -2.2,
 ":')": 2.3,
 ":'-(": -2.4,
 ":'-)": 2.7,
 ':(': -1.9,
 ':)': 2.0,
 ':*': 2.5,
 ':-###..': -2.5,
 ':-&': -0.5,
 ':-(': -1.5,
 ':-)': 1.3,
 ':-))': 2.8,
 ':-*': 1.7,
 ':-,': 1.1,
 ':-.': -0.9,
 ':-/': -1.2,
 ':-<': -1.5,
 ':-d': 2.3,
 ':-D': 2.3,
 ':-o': 0.1,
 ':-p': 1.5,
 ':-[': -1.6,
 ':-\\': -0.9,
 ':-c': -1.3,
 ':-|': -0.7,
 ':-||': -2.5,
 ':-Þ': 0.9,
 ':/': -1.4,
 ':3': 2.3,
 ':<': -2.1,
 ':>': 2.1,
 ':?)': 1.3,
 ':?c': -1.6,
 ':@': -2.5,
 ':d': 2.3,
 ':D': 2.3,
 ':l': -1.7,
 ':o': -0.4,
 ':p': 1.0,
 ':s': -1.2,
 ':[': -2.0,
 ':\\': -1.3,
 ':]': 2.2,
 ':^)': 2.1,
 ':^*': 2.6,
 ':^/': -1.2,
 ':^\\': -1.0,
 ':^|': -1.0,
 ':c': -2.1,
 ':c)': 2.0,
 ':o)': 2.1,
 ':o/': -1.4,
 ':o\\': -1.1,
 ':o|': -0.6,
 ':P': 1.4,
 ':{': -1.9,
 ':|': -0.4,
 ':}': 2.1,
 ':Þ': 1.1,
 ';)': 0.9,
 ';-)': 1.0,
 ';-*': 2.2,
 ';-]': 0.7,
 ';d': 0.8,
 ';D': 0.8,
 ';]': 0.6,
 ';^)': 1.4,
 '</3': -3.0,
 '<3': 1.9,
 '<:': 2.1,
 '<:-|': -1.4,
 '=)': 2.2,
 '=-3': 2.0,
 '=-d': 2.4,
 '=-D': 2.4,
 '=/': -1.4,
 '=3': 2.1,
 '=d': 2.3,
 '=D': 2.3,
 '=l': -1.2,
 '=\\': -1.2,
 '=]': 1.6,
 '=p': 1.3,
 '=|': -0.8,
 '>-:': -2.0,
 '>.<': -1.3,
 '>:': -2.1,
 '>:(': -2.7,
 '>:)': 0.4,
 '>:-(': -2.7,
 '>:-)': -0.4,
 '>:/': -1.6,
 '>:o': -1.2,
 '>:p': 1.0,
 '>:[': -2.1,
 '>:\\': -1.7,
 '>;(': -2.9,
 '>;)': 0.1,
 '>_>^': 2.1,
 '@:': -2.1,
 '@>-->--': 2.1,
 "@}-;-'---": 2.2,
 'aas': 2.5,
 'aayf': 2.7,
 'afu': -2.9,
 'alol': 2.8,
 'ambw': 2.9,
 'aml': 3.4,
 'atab': -1.9,
 'awol': -1.3,
 'ayc': 0.2,
 'ayor': -1.2,
 'aug-00': 0.3,
 'bfd': -2.7,
 'bfe': -2.6,
 'bff': 2.9,
 'bffn': 1.0,
 'bl': 2.3,
 'bsod': -2.2,
 'btd': -2.1,
 'btdt': -0.1,
 'bz': 0.4,
 'b^d': 2.6,
 'cwot': -2.3,
 "d-':": -2.5,
 'd8': -3.2,
 'd:': 1.2,
 'd:<': -3.2,
 'd;': -2.9,
 'd=': 1.5,
 'doa': -2.3,
 'dx': -3.0,
 'ez': 1.5,
 'fav': 2.0,
 'fcol': -1.8,
 'ff': 1.8,
 'ffs': -2.8,
 'fkm': -2.4,
 'foaf': 1.8,
 'ftw': 2.0,
 'fu': -3.7,
 'fubar': -3.0,
 'fwb': 2.5,
 'fyi': 0.8,
 'fysa': 0.4,
 'g1': 1.4,
 'gg': 1.2,
 'gga': 1.7,
 'gigo': -0.6,
 'gj': 2.0,
 'gl': 1.3,
 'gla': 2.5,
 'gn': 1.2,
 'gr8': 2.7,
 'grrr': -0.4,
 'gt': 1.1,
 'h&k': 2.3,
 'hagd': 2.2,
 'hagn': 2.2,
 'hago': 1.2,
 'hak': 1.9,
 'hand': 2.2,
 'heart': 3.2,
 'hearts': 3.3,
 'hho1/2k': 1.4,
 'hhoj': 2.0,
 'hhok': 0.9,
 'hugz': 2.0,
 'hi5': 1.9,
 'idk': -0.4,
 'ijs': 0.7,
 'ilu': 3.4,
 'iluaaf': 2.7,
 'ily': 3.4,
 'ily2': 2.6,
 'iou': 0.7,
 'iyq': 2.3,
 'j/j': 2.0,
 'j/k': 1.6,
 'j/p': 1.4,
 'j/t': -0.2,
 'j/w': 1.0,
 'j4f': 1.4,
 'j4g': 1.7,
 'jho': 0.8,
 'jhomf': 1.0,
 'jj': 1.0,
 'jk': 0.9,
 'jp': 0.8,
 'jt': 0.9,
 'jw': 1.6,
 'jealz': -1.2,
 'k4y': 2.3,
 'kfy': 2.3,
 'kia': -3.2,
 'kk': 1.5,
 'kmuf': 2.2,
 'l': 2.0,
 'l&r': 2.2,
 'laoj': 1.3,
 'lmao': 2.9,
 'lmbao': 1.8,
 'lmfao': 2.5,
 'lmso': 2.7,
 'lol': 1.8,
 'lolz': 2.7,
 'lts': 1.6,
 'ly': 2.6,
 'ly4e': 2.7,
 'lya': 3.3,
 'lyb': 3.0,
 'lyl': 3.1,
 'lylab': 2.7,
 'lylas': 2.6,
 'lylb': 1.6,
 'm8': 1.4,
 'mia': -1.2,
 'mml': 2.0,
 'mofo': -2.4,
 'muah': 2.3,
 'mubar': -1.0,
 'musm': 0.9,
 'mwah': 2.5,
 'n1': 1.9,
 'nbd': 1.3,
 'nbif': -0.5,
 'nfc': -2.7,
 'nfw': -2.4,
 'nh': 2.2,
 'nimby': -0.8,
 'nimjd': -0.7,
 'nimq': -0.2,
 'nimy': -1.4,
 'nitl': -1.5,
 'nme': -2.1,
 'noyb': -0.7,
 'np': 1.4,
 'ntmu': 1.4,
 'o-8': -0.5,
 'o-:': -0.3,
 'o-|': -1.1,
 'o.o': -0.8,
 'O.o': -0.6,
 'o.O': -0.6,
 'o:': -0.2,
 'o:)': 1.5,
 'o:-)': 2.0,
 'o:-3': 2.2,
 'o:3': 2.3,
 'o:<': -0.3,
 'o;^)': 1.6,
 'ok': 1.2,
 'o_o': -0.5,
 'O_o': -0.5,
 'o_O': -0.5,
 'pita': -2.4,
 'pls': 0.3,
 'plz': 0.3,
 'pmbi': 0.8,
 'pmfji': 0.3,
 'pmji': 0.7,
 'po': -2.6,
 'ptl': 2.6,
 'pu': -1.1,
 'qq': -2.2,
 'qt': 1.8,
 'r&r': 2.4,
 'rofl': 2.7,
 'roflmao': 2.5,
 'rotfl': 2.6,
 'rotflmao': 2.8,
 'rotflmfao': 2.5,
 'rotflol': 3.0,
 'rotgl': 2.9,
 'rotglmao': 1.8,
 's:': -1.1,
 'sapfu': -1.1,
 'sete': 2.8,
 'sfete': 2.7,
 'sgtm': 2.4,
 'slap': 0.6,
 'slaw': 2.1,
 'smh': -1.3,
 'snafu': -2.5,
 'sob': -1.0,
 'swak': 2.3,
 'tgif': 2.3,
 'thks': 1.4,
 'thx': 1.5,
 'tia': 2.3,
 'tmi': -0.3,
 'tnx': 1.1,
 'true': 1.8,
 'tx': 1.5,
 'txs': 1.1,
 'ty': 1.6,
 'tyvm': 2.5,
 'urw': 1.9,
 'vbg': 2.1,
 'vbs': 3.1,
 'vip': 2.3,
 'vwd': 2.6,
 'vwp': 2.1,
 'wag': -0.2,
 'wd': 2.7,
 'wilco': 0.9,
 'wp': 1.0,
 'wtf': -2.8,
 'wtg': 2.1,
 'wth': -2.4,
 'x-d': 2.6,
 'x-p': 1.7,
 'xd': 2.8,
 'xlnt': 3.0,
 'xoxo': 3.0,
 'xoxozzz': 2.3,
 'xp': 1.6,
 'xqzt': 1.6,
 'xtc': 0.8,
 'yolo': 1.1,
 'yoyo': 0.4,
 'yvw': 1.6,
 'yw': 1.8,
 'ywia': 2.5,
 'zzz': -1.2,
 '[-;': 0.5,
 '[:': 1.3,
 '[;': 1.0,
 '[=': 1.7,
 '\\-:': -1.0,
 '\\:': -1.0,
 '\\:<': -1.7,
 '\\=': -1.1,
 '\\^:': -1.3,
 '\\o/': 2.2,
 '\\o:': -1.2,
 ']-:': -2.1,
 ']:': -1.6,
 ']:<': -2.5,
 '^<_<': 1.4,
 '^urs': -2.8,
 'abandon': -1.9,
 'abandoned': -2.0,
 'abandoner': -1.9,
 'abandoners': -1.9,
 'abandoning': -1.6,
 'abandonment': -2.4,
 'abandonments': -1.7,
 'abandons': -1.3,
 'abducted': -2.3,
 'abduction': -2.8,
 'abductions': -2.0,
 'abhor': -2.0,
 'abhorred': -2.4,
 'abhorrent': -3.1,
 'abhors': -2.9,
 'abilities': 1.0,
 'ability': 1.3,
 'aboard': 0.1,
 'absentee': -1.1,
 'absentees': -0.8,
 'absolve': 1.2,
 'absolved': 1.5,
 'absolves': 1.3,
 'absolving': 1.6,
 'abuse': -3.2,
 'abused': -2.3,
 'abuser': -2.6,
 'abusers': -2.6,
 'abuses': -2.6,
 'abusing': -2.0,
 'abusive': -3.2,
 'abusively': -2.8,
 'abusiveness': -2.5,
 'abusivenesses': -3.0,
 'accept': 1.6,
 'acceptabilities': 1.6,
 'acceptability': 1.1,
 'acceptable': 1.3,
 'acceptableness': 1.3,
 'acceptably': 1.5,
 'acceptance': 2.0,
 'acceptances': 1.7,
 'acceptant': 1.6,
 'acceptation': 1.3,
 'acceptations': 0.9,
 'accepted': 1.1,
 'accepting': 1.6,
 'accepts': 1.3,
 'accident': -2.1,
 'accidental': -0.3,
 'accidentally': -1.4,
 'accidents': -1.3,
 'accomplish': 1.8,
 'accomplished': 1.9,
 'accomplishes': 1.7,
 'accusation': -1.0,
 'accusations': -1.3,
 'accuse': -0.8,
 'accused': -1.2,
 'accuses': -1.4,
 'accusing': -0.7,
 'ache': -1.6,
 'ached': -1.6,
 'aches': -1.0,
 'achievable': 1.3,
 'aching': -2.2,
 'acquit': 0.8,
 'acquits': 0.1,
 'acquitted': 1.0,
 'acquitting': 1.3,
 'acrimonious': -1.7,
 'active': 1.7,
 'actively': 1.3,
 'activeness': 0.6,
 'activenesses': 0.8,
 'actives': 1.1,
 'adequate': 0.9,
 'admirability': 2.4,
 'admirable': 2.6,
 'admirableness': 2.2,
 'admirably': 2.5,
 'admiral': 1.3,
 'admirals': 1.5,
 'admiralties': 1.6,
 'admiralty': 1.2,
 'admiration': 2.5,
 'admirations': 1.6,
 'admire': 2.1,
 'admired': 2.3,
 'admirer': 1.8,
 'admirers': 1.7,
 'admires': 1.5,
 'admiring': 1.6,
 'admiringly': 2.3,
 'admit': 0.8,
 'admits': 1.2,
 'admitted': 0.4,
 'admonished': -1.9,
 'adopt': 0.7,
 'adopts': 0.7,
 'adorability': 2.2,
 'adorable': 2.2,
 'adorableness': 2.5,
 'adorably': 2.1,
 'adoration': 2.9,
 'adorations': 2.2,
 'adore': 2.6,
 'adored': 1.8,
 'adorer': 1.7,
 'adorers': 2.1,
 'adores': 1.6,
 'adoring': 2.6,
 'adoringly': 2.4,
 'adorn': 0.9,
 'adorned': 0.8,
 'adorner': 1.3,
 'adorners': 0.9,
 'adorning': 1.0,
 'adornment': 1.3,
 'adornments': 0.8,
 'adorns': 0.5,
 'advanced': 1.0,
 'advantage': 1.0,
 'advantaged': 1.4,
 'advantageous': 1.5,
 'advantageously': 1.9,
 'advantageousness': 1.6,
 'advantages': 1.5,
 'advantaging': 1.6,
 'adventure': 1.3,
 'adventured': 1.3,
 'adventurer': 1.2,
 'adventurers': 0.9,
 'adventures': 1.4,
 'adventuresome': 1.7,
 'adventuresomeness': 1.3,
 'adventuress': 0.8,
 'adventuresses': 1.4,
 'adventuring': 2.3,
 'adventurism': 1.5,
 'adventurist': 1.4,
 'adventuristic': 1.7,
 'adventurists': 1.2,
 'adventurous': 1.4,
 'adventurously': 1.3,
 'adventurousness': 1.8,
 'adversarial': -1.5,
 'adversaries': -1.0,
 'adversary': -0.8,
 'adversative': -1.2,
 'adversatively': -0.1,
 'adversatives': -1.0,
 'adverse': -1.5,
 'adversely': -0.8,
 'adverseness': -0.6,
 'adversities': -1.5,
 'adversity': -1.8,
 'affected': -0.6,
 'affection': 2.4,
 'affectional': 1.9,
 'affectionally': 1.5,
 'affectionate': 1.9,
 'affectionately': 2.2,
 'affectioned': 1.8,
 'affectionless': -2.0,
 'affections': 1.5,
 'afflicted': -1.5,
 'affronted': 0.2,
 'aggravate': -2.5,
 'aggravated': -1.9,
 'aggravates': -1.9,
 'aggravating': -1.2,
 'aggress': -1.3,
 'aggressed': -1.4,
 'aggresses': -0.5,
 'aggressing': -0.6,
 'aggression': -1.2,
 'aggressions': -1.3,
 'aggressive': -0.6,
 'aggressively': -1.3,
 'aggressiveness': -1.8,
 'aggressivities': -1.4,
 'aggressivity': -0.6,
 'aggressor': -0.8,
 'aggressors': -0.9,
 'aghast': -1.9,
 'agitate': -1.7,
 'agitated': -2.0,
 'agitatedly': -1.6,
 'agitates': -1.4,
 'agitating': -1.8,
 'agitation': -1.0,
 'agitational': -1.2,
 'agitations': -1.3,
 'agitative': -1.3,
 'agitato': -0.1,
 'agitator': -1.4,
 'agitators': -2.1,
 'agog': 1.9,
 'agonise': -2.1,
 'agonised': -2.3,
 'agonises': -2.4,
 'agonising': -1.5,
 'agonize': -2.3,
 'agonized': -2.2,
 'agonizes': -2.3,
 'agonizing': -2.7,
 'agonizingly': -2.3,
 'agony': -1.8,
 'agree': 1.5,
 'agreeability': 1.9,
 'agreeable': 1.8,
 'agreeableness': 1.8,
 'agreeablenesses': 1.3,
 'agreeably': 1.6,
 'agreed': 1.1,
 'agreeing': 1.4,
 'agreement': 2.2,
 'agreements': 1.1,
 'agrees': 0.8,
 'alarm': -1.4,
 'alarmed': -1.4,
 'alarming': -0.5,
 'alarmingly': -2.6,
 'alarmism': -0.3,
 'alarmists': -1.1,
 'alarms': -1.1,
 'alas': -1.1,
 'alert': 1.2,
 'alienation': -1.1,
 'alive': 1.6,
 'allergic': -1.2,
 'allow': 0.9,
 'alone': -1.0,
 'alright': 1.0,
 'amaze': 2.5,
 'amazed': 2.2,
 'amazedly': 2.1,
 'amazement': 2.5,
 'amazements': 2.2,
 'amazes': 2.2,
 'amazing': 2.8,
 'amazon': 0.7,
 'amazonite': 0.2,
 'amazons': -0.1,
 'amazonstone': 1.0,
 'amazonstones': 0.2,
 'ambitious': 2.1,
 'ambivalent': 0.5,
 'amor': 3.0,
 'amoral': -1.6,
 'amoralism': -0.7,
 'amoralisms': -0.7,
 'amoralities': -1.2,
 'amorality': -1.5,
 'amorally': -1.0,
 'amoretti': 0.2,
 'amoretto': 0.6,
 'amorettos': 0.3,
 'amorino': 1.2,
 'amorist': 1.6,
 'amoristic': 1.0,
 'amorists': 0.1,
 'amoroso': 2.3,
 'amorous': 1.8,
 'amorously': 2.3,
 'amorousness': 2.0,
 'amorphous': -0.2,
 'amorphously': 0.1,
 'amorphousness': 0.3,
 'amort': -2.1,
 'amortise': 0.5,
 'amortised': -0.2,
 'amortises': 0.1,
 'amortizable': 0.5,
 'amortization': 0.6,
 'amortizations': 0.2,
 'amortize': -0.1,
 'amortized': 0.8,
 'amortizes': 0.6,
 'amortizing': 0.8,
 'amusable': 0.7,
 'amuse': 1.7,
 'amused': 1.8,
 'amusedly': 2.2,
 'amusement': 1.5,
 'amusements': 1.5,
 'amuser': 1.1,
 'amusers': 1.3,
 'amuses': 1.7,
 'amusia': 0.3,
 'amusias': -0.4,
 'amusing': 1.6,
 'amusingly': 0.8,
 'amusingness': 1.8,
 'amusive': 1.7,
 'anger': -2.7,
 'angered': -2.3,
 'angering': -2.2,
 'angerly': -1.9,
 'angers': -2.3,
 'angrier': -2.3,
 'angriest': -3.1,
 'angrily': -1.8,
 'angriness': -1.7,
 'angry': -2.3,
 'anguish': -2.9,
 'anguished': -1.8,
 'anguishes': -2.1,
 'anguishing': -2.7,
 'animosity': -1.9,
 'annoy': -1.9,
 'annoyance': -1.3,
 'annoyances': -1.8,
 'annoyed': -1.6,
 'annoyer': -2.2,
 'annoyers': -1.5,
 'annoying': -1.7,
 'annoys': -1.8,
 'antagonism': -1.9,
 'antagonisms': -1.2,
 'antagonist': -1.9,
 'antagonistic': -1.7,
 'antagonistically': -2.2,
 'antagonists': -1.7,
 'antagonize': -2.0,
 'antagonized': -1.4,
 'antagonizes': -0.5,
 'antagonizing': -2.7,
 'anti': -1.3,
 'anticipation': 0.4,
 'anxieties': -0.6,
 'anxiety': -0.7,
 'anxious': -1.0,
 'anxiously': -0.9,
 'anxiousness': -1.0,
 'aok': 2.0,
 'apathetic': -1.2,
 'apathetically': -0.4,
 'apathies': -0.6,
 'apathy': -1.2,
 'apeshit': -0.9,
 'apocalyptic': -3.4,
 'apologise': 1.6,
 'apologised': 0.4,
 'apologises': 0.8,
 'apologising': 0.2,
 'apologize': 0.4,
 'apologized': 1.3,
 'apologizes': 1.5,
 'apologizing': -0.3,
 'apology': 0.2,
 'appall': -2.4,
 'appalled': -2.0,
 'appalling': -1.5,
 'appallingly': -2.0,
 'appalls': -1.9,
 'appease': 1.1,
 'appeased': 0.9,
 'appeases': 0.9,
 'appeasing': 1.0,
 'applaud': 2.0,
 'applauded': 1.5,
 'applauding': 2.1,
 'applauds': 1.4,
 'applause': 1.8,
 'appreciate': 1.7,
 'appreciated': 2.3,
 'appreciates': 2.3,
 'appreciating': 1.9,
 'appreciation': 2.3,
 'appreciations': 1.7,
 'appreciative': 2.6,
 'appreciatively': 1.8,
 'appreciativeness': 1.6,
 'appreciator': 2.6,
 'appreciators': 1.5,
 'appreciatory': 1.7,
 'apprehensible': 1.1,
 'apprehensibly': -0.2,
 'apprehension': -2.1,
 'apprehensions': -0.9,
 'apprehensively': -0.3,
 'apprehensiveness': -0.7,
 'approval': 2.1,
 'approved': 1.8,
 'approves': 1.7,
 'ardent': 2.1,
 'arguable': -1.0,
 'arguably': -1.0,
 'argue': -1.4,
 'argued': -1.5,
 'arguer': -1.6,
 'arguers': -1.4,
 'argues': -1.6,
 'arguing': -2.0,
 'argument': -1.5,
 'argumentative': -1.5,
 'argumentatively': -1.8,
 'argumentive': -1.5,
 'arguments': -1.7,
 'arrest': -1.4,
 'arrested': -2.1,
 'arrests': -1.9,
 'arrogance': -2.4,
 'arrogances': -1.9,
 'arrogant': -2.2,
 'arrogantly': -1.8,
 'ashamed': -2.1,
 'ashamedly': -1.7,
 'ass': -2.5,
 'assassination': -2.9,
 'assassinations': -2.7,
 'assault': -2.8,
 'assaulted': -2.4,
 'assaulting': -2.3,
 'assaultive': -2.8,
 'assaults': -2.5,
 'asset': 1.5,
 'assets': 0.7,
 'assfucking': -2.5,
 'assholes': -2.8,
 'assurance': 1.4,
 'assurances': 1.4,
 'assure': 1.4,
 'assured': 1.5,
 'assuredly': 1.6,
 'assuredness': 1.4,
 'assurer': 0.9,
 'assurers': 1.1,
 'assures': 1.3,
 'assurgent': 1.3,
 'assuring': 1.6,
 'assuror': 0.5,
 'assurors': 0.7,
 'astonished': 1.6,
 'astound': 1.7,
 'astounded': 1.8,
 'astounding': 1.8,
 'astoundingly': 2.1,
 'astounds': 2.1,
 'attachment': 1.2,
 'attachments': 1.1,
 'attack': -2.1,
 'attacked': -2.0,
 'attacker': -2.7,
 'attackers': -2.7,
 'attacking': -2.0,
 'attacks': -1.9,
 'attract': 1.5,
 'attractancy': 0.9,
 'attractant': 1.3,
 'attractants': 1.4,
 'attracted': 1.8,
 'attracting': 2.1,
 'attraction': 2.0,
 'attractions': 1.8,
 'attractive': 1.9,
 'attractively': 2.2,
 'attractiveness': 1.8,
 'attractivenesses': 2.1,
 'attractor': 1.2,
 'attractors': 1.2,
 'attracts': 1.7,
 'audacious': 0.9,
 'authority': 0.3,
 'aversion': -1.9,
 'aversions': -1.1,
 'aversive': -1.6,
 'aversively': -0.8,
 'avert': -0.7,
 'averted': -0.3,
 'averts': -0.4,
 'avid': 1.2,
 'avoid': -1.2,
 'avoidance': -1.7,
 'avoidances': -1.1,
 'avoided': -1.4,
 'avoider': -1.8,
 'avoiders': -1.4,
 'avoiding': -1.4,
 'avoids': -0.7,
 'await': 0.4,
 'awaited': -0.1,
 'awaits': 0.3,
 'award': 2.5,
 'awardable': 2.4,
 'awarded': 1.7,
 'awardee': 1.8,
 'awardees': 1.2,
 'awarder': 0.9,
 'awarders': 1.3,
 'awarding': 1.9,
 'awards': 2.0,
 'awesome': 3.1,
 'awful': -2.0,
 'awkward': -0.6,
 'awkwardly': -1.3,
 'awkwardness': -0.7,
 'axe': -0.4,
 'axed': -1.3,
 'backed': 0.1,
 'backing': 0.1,
 'backs': -0.2,
 'bad': -2.5,
 'badass': 1.4,
 'badly': -2.1,
 'bailout': -0.4,
 'bamboozle': -1.5,
 'bamboozled': -1.5,
 'bamboozles': -1.5,
 'ban': -2.6,
 'banish': -1.9,
 'bankrupt': -2.6,
 'bankster': -2.1,
 'banned': -2.0,
 'bargain': 0.8,
 'barrier': -0.5,
 'bashful': -0.1,
 'bashfully': 0.2,
 'bashfulness': -0.8,
 'bastard': -2.5,
 'bastardies': -1.8,
 'bastardise': -2.1,
 'bastardised': -2.3,
 'bastardises': -2.3,
 'bastardising': -2.6,
 'bastardization': -2.4,
 'bastardizations': -2.1,
 'bastardize': -2.4,
 'bastardized': -2.0,
 'bastardizes': -1.8,
 'bastardizing': -2.3,
 'bastardly': -2.7,
 'bastards': -3.0,
 'bastardy': -2.7,
 'battle': -1.6,
 'battled': -1.2,
 'battlefield': -1.6,
 'battlefields': -0.9,
 'battlefront': -1.2,
 'battlefronts': -0.8,
 'battleground': -1.7,
 'battlegrounds': -0.6,
 'battlement': -0.4,
 'battlements': -0.4,
 'battler': -0.8,
 'battlers': -0.2,
 'battles': -1.6,
 'battleship': -0.1,
 'battleships': -0.5,
 'battlewagon': -0.3,
 'battlewagons': -0.5,
 'battling': -1.1,
 'beaten': -1.8,
 'beatific': 1.8,
 'beating': -2.0,
 'beaut': 1.6,
 'beauteous': 2.5,
 'beauteously': 2.6,
 ...}

Et voilà comment on récupère la représentation d'un document

In [8]:
def featurize(text, lexicon):
    words = poor_mans_tokenizer_and_normalizer(text)
    features = np.empty(2)
    # Le max permet de remonter les polarités négatives à 0
    features[0] = sum(max(lexicon.get(w, 0), 0) for w in words)/len(words)
    features[1] = sum(max(-lexicon.get(w, 0), 0) for w in words)/len(words)
    return features

On teste ?

In [9]:
doc = "I came in in the middle of this film so I had no idea about any credits or even its title till I looked it up here, where I see that it has received a mixed reception by your commentators. I'm on the positive side regarding this film but one thing really caught my attention as I watched: the beautiful and sensitive score written in a Coplandesque Americana style. My surprise was great when I discovered the score to have been written by none other than John Williams himself. True he has written sensitive and poignant scores such as Schindler's List but one usually associates his name with such bombasticities as Star Wars. But in my opinion what Williams has written for this movie surpasses anything I've ever heard of his for tenderness, sensitivity and beauty, fully in keeping with the tender and lovely plot of the movie. And another recent score of his, for Catch Me if You Can, shows still more wit and sophistication. As to Stanley and Iris, I like education movies like How Green was my Valley and Konrack, that one with John Voigt and his young African American charges in South Carolina, and Danny deVito's Renaissance Man, etc. They tell a necessary story of intellectual and spiritual awakening, a story which can't be told often enough. This one is an excellent addition to that genre."
doc_features = featurize(doc, lexicon)
doc_features
Out[9]:
array([0.12085106, 0.02085106])

2. Vectoriser un corpus¶

Appliquer la fonction précédente sur le mini-corpus IMDB

🧠 Correction 2 🧠¶

Commençons par l'extraire

In [10]:
%%bash
cd ../../local
tar -xzf ../data/imdb_smol.tar.gz 
ls -lah imdb_smol
tar: Ignoring unknown extended header keyword 'SCHILY.fflags'
tar: Ignoring unknown extended header keyword 'SCHILY.fflags'
tar: Ignoring unknown extended header keyword 'SCHILY.fflags'
total 32K
drwxr-xr-x 4 runner docker 4.0K Dec  4  2018 .
drwxr-xr-x 4 runner docker 4.0K Feb  7 19:26 ..
drwxr-xr-x 2 runner docker  12K Dec  4  2018 neg
drwxr-xr-x 2 runner docker  12K Dec  4  2018 pos

Maintenant on parcourt le dossier pour construire nos représentations

In [11]:
from collections import defaultdict
import pathlib  # Manipuler des chemins et des fichiers agréablement

def featurize_dir(corpus_root, lexicon):
    corpus_root = pathlib.Path(corpus_root)
    res = defaultdict(list)
    for clss in corpus_root.iterdir():
        # On peut aussi utiliser une compréhension de liste et avoir un dict pas default
        for doc in clss.iterdir():
            # `stem` et `read_text` c'est de la magie de `pathlib`, check it out
            res[clss.stem].append(featurize(doc.read_text(), lexicon))
    return res

# On réutilise le lexique précédent
imdb_features = featurize_dir("../../local/imdb_smol", lexicon)
imdb_features
Out[11]:
defaultdict(list,
            {'pos': [array([0.06206262, 0.15138122]),
              array([0.2625, 0.0775]),
              array([0.16530612, 0.05020408]),
              array([0.0439759 , 0.08162651]),
              array([0.06440177, 0.05317578]),
              array([0.18955224, 0.01641791]),
              array([0.11556684, 0.03739425]),
              array([0.07440476, 0.05297619]),
              array([0.08363636, 0.        ]),
              array([0.18913043, 0.09492754]),
              array([0.14527027, 0.05726351]),
              array([0.16593407, 0.04871795]),
              array([0.19492754, 0.04927536]),
              array([0.1587156 , 0.03990826]),
              array([0.17993197, 0.04217687]),
              array([0.09068736, 0.07117517]),
              array([0.06514523, 0.08921162]),
              array([0.14303797, 0.03797468]),
              array([0.20518519, 0.01777778]),
              array([0.08285714, 0.05785714]),
              array([0.19055118, 0.01102362]),
              array([0.07637795, 0.07165354]),
              array([0.29772727, 0.        ]),
              array([0.10900901, 0.02792793]),
              array([0.23209877, 0.00679012]),
              array([0.11280788, 0.0320197 ]),
              array([0.05220339, 0.05966102]),
              array([0.13240741, 0.03518519]),
              array([0.12521008, 0.03193277]),
              array([0.08402062, 0.0242268 ]),
              array([0.07474333, 0.06303901]),
              array([0.18455882, 0.        ]),
              array([0.1052459 , 0.09409836]),
              array([0.12185792, 0.10874317]),
              array([0.10520446, 0.07881041]),
              array([0.09520384, 0.01846523]),
              array([0.06045198, 0.09491525]),
              array([0.10887097, 0.03064516]),
              array([0.06982507, 0.05918367]),
              array([0.12360248, 0.02608696]),
              array([0.08244576, 0.11893491]),
              array([0.08871795, 0.05179487]),
              array([0.28333333, 0.        ]),
              array([0.11904762, 0.03285714]),
              array([0.12085106, 0.02085106]),
              array([0.15340909, 0.05170455]),
              array([0.06346154, 0.075     ]),
              array([0.18005698, 0.03447293]),
              array([0.05973597, 0.02112211]),
              array([0.16333333, 0.04388889]),
              array([0.11803279, 0.03442623]),
              array([0.08019802, 0.        ]),
              array([0.25777778, 0.07555556]),
              array([0.12985075, 0.01455224]),
              array([0.17478992, 0.        ]),
              array([0.21481481, 0.01358025]),
              array([0.16632997, 0.        ]),
              array([0.17894737, 0.02894737]),
              array([0.1968254 , 0.12063492]),
              array([0.04651163, 0.0627907 ]),
              array([0.25590551, 0.        ]),
              array([0.11265823, 0.05696203]),
              array([0.10625   , 0.01420455]),
              array([0.07647059, 0.12249135]),
              array([0.07575758, 0.02878788]),
              array([0.10504732, 0.1170347 ]),
              array([0.17, 0.  ]),
              array([0.12734375, 0.03164063]),
              array([0.12054795, 0.01917808]),
              array([0.09264892, 0.04626109]),
              array([0.05121951, 0.04731707]),
              array([0.12723005, 0.02253521]),
              array([0.057277  , 0.06737089]),
              array([0.07568058, 0.0553539 ]),
              array([0.15042017, 0.08319328]),
              array([0.09568627, 0.08058824]),
              array([0.15467836, 0.02387914]),
              array([0.05909091, 0.0385101 ]),
              array([0.22204082, 0.02897959]),
              array([0.03370787, 0.04269663]),
              array([0.24691358, 0.        ]),
              array([0.10547112, 0.04042553]),
              array([0.10923913, 0.07282609]),
              array([0.08321429, 0.09178571]),
              array([0.05347594, 0.08235294]),
              array([0.04957983, 0.04453782]),
              array([0.23857143, 0.00285714]),
              array([0.1480315 , 0.02834646]),
              array([0.21216216, 0.01283784]),
              array([0.16586103, 0.03897281]),
              array([0.28791209, 0.04505495]),
              array([0.11743421, 0.02565789]),
              array([0.05776398, 0.07267081]),
              array([0.09212598, 0.05984252]),
              array([0.13536585, 0.07378049]),
              array([0.07922849, 0.0379822 ]),
              array([0.08812785, 0.08721461]),
              array([0.22692308, 0.04      ]),
              array([0.13282443, 0.00916031]),
              array([0.09695817, 0.04638783]),
              array([0.1936255 , 0.06135458]),
              array([0.17019231, 0.05      ]),
              array([0.07070313, 0.059375  ]),
              array([0.09032258, 0.05096774]),
              array([0.18070175, 0.0622807 ]),
              array([0.13979592, 0.01326531]),
              array([0.12515723, 0.00754717]),
              array([0.07836257, 0.03391813]),
              array([0.14771838, 0.02685789]),
              array([0.08368794, 0.05035461]),
              array([0.12193959, 0.02925278]),
              array([0.08303249, 0.05090253]),
              array([0.08783069, 0.05661376]),
              array([0.07407407, 0.        ]),
              array([0.14641148, 0.03301435]),
              array([0.13382789, 0.09525223]),
              array([0.18062016, 0.08449612]),
              array([0.025     , 0.02785714]),
              array([0.13609467, 0.0260355 ]),
              array([0.13333333, 0.01833333]),
              array([0.08187919, 0.0261745 ]),
              array([0.11531532, 0.09369369]),
              array([0.08128342, 0.00374332]),
              array([0.1380597 , 0.01119403]),
              array([0.0569378 , 0.06220096]),
              array([0.0802139 , 0.07058824]),
              array([0.08717949, 0.03931624]),
              array([0.2119403 , 0.08656716]),
              array([0.07181208, 0.00805369]),
              array([0.06254296, 0.03883162]),
              array([0.0490566 , 0.06415094]),
              array([0.15692308, 0.05230769]),
              array([0.08983957, 0.12513369]),
              array([0.11176471, 0.15294118]),
              array([0.12947368, 0.03684211]),
              array([0.15744681, 0.03723404]),
              array([0.096875, 0.      ]),
              array([0.2462963, 0.       ]),
              array([0.16416667, 0.04083333]),
              array([0.2283871, 0.       ]),
              array([0.09496855, 0.07374214]),
              array([0.1       , 0.02355769]),
              array([0.19152542, 0.02711864]),
              array([0.15185185, 0.13111111]),
              array([0.11032967, 0.03802198]),
              array([0.08766234, 0.0512987 ]),
              array([0.11784777, 0.04356955]),
              array([0.07786885, 0.02213115]),
              array([0.12865014, 0.0661157 ]),
              array([0.09785203, 0.02673031]),
              array([0.24299065, 0.09906542]),
              array([0.26095238, 0.        ]),
              array([0.09246575, 0.09726027]),
              array([0.14294479, 0.02699387]),
              array([0.08505747, 0.02586207]),
              array([0.08963134, 0.10529954]),
              array([0.1503012 , 0.02409639]),
              array([0.07941176, 0.02352941]),
              array([0.13214286, 0.18214286]),
              array([0.11232877, 0.00684932]),
              array([0.07275748, 0.05813953]),
              array([0.0800995 , 0.06716418]),
              array([0.12371795, 0.0275641 ]),
              array([0.0908377 , 0.02041885]),
              array([0.19777778, 0.02755556]),
              array([0.11013825, 0.05898618]),
              array([0.03544669, 0.03602305]),
              array([0.09836957, 0.01630435]),
              array([0.11854305, 0.04635762]),
              array([0.25217391, 0.        ]),
              array([0.15580448, 0.0287169 ]),
              array([0.18409091, 0.07954545]),
              array([0.06923077, 0.00710059]),
              array([0.12105263, 0.01710526]),
              array([0.12910663, 0.02334294]),
              array([0.1352518 , 0.03453237]),
              array([0.11325758, 0.01799242]),
              array([0.07833333, 0.10916667]),
              array([0.14125, 0.04   ]),
              array([0.17794118, 0.09705882]),
              array([0.14201183, 0.02840237]),
              array([0.09414634, 0.01853659]),
              array([0.125     , 0.01111111]),
              array([0.1779661 , 0.02372881]),
              array([0.19411765, 0.        ]),
              array([0.06230769, 0.15      ]),
              array([0.11921922, 0.10930931]),
              array([0.09040404, 0.03030303]),
              array([0.071 , 0.1195]),
              array([0.09275093, 0.03773234]),
              array([0.31666667, 0.        ]),
              array([0.13467742, 0.        ]),
              array([0.14228188, 0.06979866]),
              array([0.0542654 , 0.10829384]),
              array([0.05200846, 0.02980973]),
              array([0.15061728, 0.        ]),
              array([0.08863636, 0.05170455]),
              array([0.09334862, 0.07591743]),
              array([0.21891892, 0.02972973]),
              array([0.17278912, 0.05578231]),
              array([0.16394558, 0.02040816]),
              array([0.08557692, 0.0375    ]),
              array([0.11912046, 0.11414914]),
              array([0.14387755, 0.09985423]),
              array([0.14545455, 0.12626263]),
              array([0.11871658, 0.        ]),
              array([0.12735043, 0.02307692]),
              array([0.08101266, 0.        ]),
              array([0.16168582, 0.0651341 ]),
              array([0.13469388, 0.04387755]),
              array([0.10526316, 0.03355263]),
              array([0.07537688, 0.0440536 ]),
              array([0.10753425, 0.0390411 ]),
              array([0.06015038, 0.12406015]),
              array([0.21573034, 0.07977528]),
              array([0.13956522, 0.01695652]),
              array([0.14444444, 0.06      ]),
              array([0.07794118, 0.04117647]),
              array([0.07329193, 0.19192547]),
              array([0.21582734, 0.00863309]),
              array([0.2630137 , 0.06575342]),
              array([0.13114754, 0.02540984]),
              array([0.3425, 0.0475]),
              array([0.21007194, 0.00863309]),
              array([0.13099415, 0.04619883]),
              array([0.09344262, 0.0442623 ]),
              array([0.15025641, 0.0974359 ]),
              array([0.14796748, 0.07723577]),
              array([0.14480519, 0.03896104]),
              array([0.09502262, 0.1239819 ]),
              array([0.21741071, 0.06428571]),
              array([0.05594315, 0.03682171]),
              array([0.07074341, 0.07122302]),
              array([0.30857143, 0.08285714]),
              array([0.16712329, 0.06780822]),
              array([0.07868852, 0.0795082 ]),
              array([0.14845361, 0.03814433]),
              array([0.14467213, 0.03770492]),
              array([0.12356322, 0.04827586]),
              array([0.06631944, 0.07118056]),
              array([0.18075802, 0.10233236]),
              array([0.09471154, 0.06923077]),
              array([0.10375   , 0.10479167]),
              array([0.04350649, 0.03636364]),
              array([0.03308271, 0.08220551]),
              array([0.06785714, 0.01919643]),
              array([0.06304348, 0.01884058]),
              array([0.11072555, 0.05315457]),
              array([0.11222222, 0.0062963 ]),
              array([0.1883871 , 0.01096774]),
              array([0.17255639, 0.01616541]),
              array([0.0659292 , 0.02389381]),
              array([0.07931034, 0.06206897]),
              array([0.10551724, 0.03862069]),
              array([0.09447005, 0.09124424]),
              array([0.10076923, 0.01615385]),
              array([0.1075    , 0.05928571]),
              array([0.22876712, 0.10547945]),
              array([0.19439252, 0.0411215 ]),
              array([0.0845, 0.055 ]),
              array([0.19510204, 0.0322449 ]),
              array([0.16160221, 0.03618785]),
              array([0.06806723, 0.00420168]),
              array([0.14375   , 0.13303571]),
              array([0.10882353, 0.00705882]),
              array([0.04705882, 0.06134454]),
              array([0.16907216, 0.02474227]),
              array([0.06122779, 0.06203554]),
              array([0.12322946, 0.11926346]),
              array([0.17945205, 0.01438356]),
              array([0.11144578, 0.06144578]),
              array([0.11434978, 0.06457399]),
              array([0.09368421, 0.03578947]),
              array([0.05578831, 0.0446527 ]),
              array([0.19302326, 0.04651163]),
              array([0.11752577, 0.        ]),
              array([0.05963855, 0.06144578]),
              array([0.12857143, 0.02631579]),
              array([0.11642857, 0.        ]),
              array([0.24351852, 0.01018519]),
              array([0.1023622, 0.0496063]),
              array([0.0671875 , 0.06953125]),
              array([0.09518717, 0.0802139 ]),
              array([0.19657534, 0.01712329]),
              array([0.16100629, 0.02830189]),
              array([0.14273504, 0.08290598]),
              array([0.12677165, 0.07795276]),
              array([0.15727273, 0.02363636]),
              array([0.05532359, 0.14467641]),
              array([0.15793651, 0.        ]),
              array([0.10192308, 0.05480769]),
              array([0.26923077, 0.07948718]),
              array([0.11712707, 0.01160221]),
              array([0.07741935, 0.07177419]),
              array([0.15747126, 0.01896552]),
              array([0.19230769, 0.05664336]),
              array([0.17153846, 0.        ]),
              array([0.09550562, 0.04719101]),
              array([0.07593985, 0.11879699]),
              array([0.06808511, 0.09219858]),
              array([0.21083333, 0.03583333])],
             'neg': [array([0.07728707, 0.11293375]),
              array([0.07990196, 0.08578431]),
              array([0.028125, 0.046875]),
              array([0.08676471, 0.08529412]),
              array([0.05379747, 0.05063291]),
              array([0.1260274 , 0.07123288]),
              array([0.06151203, 0.09965636]),
              array([0.08092486, 0.05144509]),
              array([0.0625    , 0.12053571]),
              array([0.11463415, 0.02439024]),
              array([0.08873239, 0.14929577]),
              array([0.16389776, 0.05015974]),
              array([0.0584507, 0.0415493]),
              array([0.03216374, 0.06959064]),
              array([0.02387097, 0.08580645]),
              array([0.12783505, 0.1       ]),
              array([0.1       , 0.02133333]),
              array([0.03      , 0.17642857]),
              array([0.10151515, 0.03151515]),
              array([0.05377834, 0.07619647]),
              array([0.08870968, 0.1       ]),
              array([0.06941176, 0.03411765]),
              array([0.05086705, 0.03583815]),
              array([0.11067416, 0.0258427 ]),
              array([0.03659849, 0.05489774]),
              array([0.14403131, 0.05303327]),
              array([0.15243446, 0.06853933]),
              array([0.07687861, 0.07687861]),
              array([0.02070175, 0.12736842]),
              array([0.04756757, 0.08054054]),
              array([0.09326923, 0.04150641]),
              array([0.04769231, 0.09076923]),
              array([0.09465649, 0.08854962]),
              array([0.09303797, 0.04493671]),
              array([0.033125, 0.081875]),
              array([0.11538462, 0.07307692]),
              array([0.07286822, 0.06821705]),
              array([0.12032967, 0.08571429]),
              array([0.0862069 , 0.09655172]),
              array([0.11415929, 0.15132743]),
              array([0.09190939, 0.08187702]),
              array([0.1847561 , 0.04146341]),
              array([0.07086093, 0.06423841]),
              array([0.05      , 0.16891892]),
              array([0.0875    , 0.05208333]),
              array([0.07731959, 0.0371134 ]),
              array([0.06234568, 0.02777778]),
              array([0.21219512, 0.0300813 ]),
              array([0.09205298, 0.04834437]),
              array([0.14615385, 0.        ]),
              array([0.14760563, 0.05211268]),
              array([0.08058252, 0.04029126]),
              array([0.06788321, 0.06788321]),
              array([0.08316498, 0.05319865]),
              array([0.15795455, 0.07159091]),
              array([0.08527607, 0.06687117]),
              array([0.0974359 , 0.10512821]),
              array([0.0620155 , 0.15426357]),
              array([0.        , 0.15060241]),
              array([0.07289294, 0.05649203]),
              array([0.08176101, 0.02044025]),
              array([0.09212121, 0.12484848]),
              array([0.0467033 , 0.14395604]),
              array([0.10069444, 0.12152778]),
              array([0.11373391, 0.0583691 ]),
              array([0.11082803, 0.05286624]),
              array([0.08 , 0.036]),
              array([0.08177083, 0.08489583]),
              array([0.0377193 , 0.14473684]),
              array([0.05038168, 0.10763359]),
              array([0.06311475, 0.1204918 ]),
              array([0.11492537, 0.08955224]),
              array([0.08253968, 0.0952381 ]),
              array([0.11869919, 0.04065041]),
              array([0.10766871, 0.04202454]),
              array([0.06896552, 0.03908046]),
              array([0.07264706, 0.07235294]),
              array([0.08119658, 0.02606838]),
              array([0.11384615, 0.08307692]),
              array([0.065     , 0.01714286]),
              array([0.0753915 , 0.11923937]),
              array([0.08024194, 0.09475806]),
              array([0.23488372, 0.14883721]),
              array([0.06080586, 0.05714286]),
              array([0.04367089, 0.18797468]),
              array([0.13371648, 0.04578544]),
              array([0.0528, 0.1416]),
              array([0.03584416, 0.18597403]),
              array([0.13266332, 0.10753769]),
              array([0.05186441, 0.03966102]),
              array([0.05391705, 0.12327189]),
              array([0.1204, 0.0136]),
              array([0.07103448, 0.11241379]),
              array([0.07626263, 0.07070707]),
              array([0.09921875, 0.0109375 ]),
              array([0.07248062, 0.15484496]),
              array([0.10595745, 0.07234043]),
              array([0.08267717, 0.0511811 ]),
              array([0.11519435, 0.06007067]),
              array([0.0704797 , 0.09372694]),
              array([0.05873016, 0.05952381]),
              array([0.12542373, 0.01101695]),
              array([0.048659  , 0.13295019]),
              array([0.11230769, 0.02730769]),
              array([0.05471698, 0.09119497]),
              array([0.03333333, 0.12345679]),
              array([0.10067568, 0.04662162]),
              array([0.11703057, 0.04759825]),
              array([0.08601399, 0.14545455]),
              array([0.08706468, 0.09303483]),
              array([0.10877193, 0.0754386 ]),
              array([0.0992674 , 0.04798535]),
              array([0.05897436, 0.17820513]),
              array([0.03522727, 0.11818182]),
              array([0.08445596, 0.12590674]),
              array([0.07      , 0.07916667]),
              array([0.05985401, 0.08029197]),
              array([0.08237705, 0.14467213]),
              array([0.1484127 , 0.00952381]),
              array([0.08481973, 0.09316888]),
              array([0.08531469, 0.04265734]),
              array([0.08109244, 0.09663866]),
              array([0.02774194, 0.04129032]),
              array([0.06444444, 0.05703704]),
              array([0.15630631, 0.06306306]),
              array([0.14      , 0.10378378]),
              array([0.05597579, 0.08517398]),
              array([0.03333333, 0.1202381 ]),
              array([0.13939394, 0.13939394]),
              array([0.16075949, 0.        ]),
              array([0.09545455, 0.04090909]),
              array([0.11639344, 0.05737705]),
              array([0.09134615, 0.07403846]),
              array([0.02130178, 0.13668639]),
              array([0.14324324, 0.11959459]),
              array([0.07305389, 0.08502994]),
              array([0.075     , 0.11470588]),
              array([0.10714286, 0.05879121]),
              array([0.06893204, 0.13883495]),
              array([0.06682692, 0.10625   ]),
              array([0.06751825, 0.05930657]),
              array([0.06538462, 0.04505495]),
              array([0.08421053, 0.09398496]),
              array([0.07276119, 0.05932836]),
              array([0.16551724, 0.03448276]),
              array([0.14132231, 0.09504132]),
              array([0.0719697 , 0.07272727]),
              array([0.0588993 , 0.10819672]),
              array([0.07914894, 0.08382979]),
              array([0.02407407, 0.06234568]),
              array([0.08924051, 0.05949367]),
              array([0.03781513, 0.1802521 ]),
              array([0.05804598, 0.10028736]),
              array([0.07361111, 0.05625   ]),
              array([0.129, 0.104]),
              array([0.16389776, 0.05015974]),
              array([0.12078652, 0.08370787]),
              array([0.0852349 , 0.03825503]),
              array([0.07410072, 0.09100719]),
              array([0.05095541, 0.02611465]),
              array([0.06535088, 0.20701754]),
              array([0.1084507 , 0.16478873]),
              array([0.06048387, 0.05645161]),
              array([0.04841629, 0.05067873]),
              array([0.08092105, 0.07434211]),
              array([0.10310078, 0.03255814]),
              array([0.10545455, 0.08484848]),
              array([0.10229008, 0.15877863]),
              array([0.06884735, 0.04548287]),
              array([0.04509804, 0.1       ]),
              array([0.05238095, 0.03333333]),
              array([0.05206612, 0.06446281]),
              array([0.06711712, 0.15630631]),
              array([0.02402827, 0.04240283]),
              array([0.25203252, 0.01626016]),
              array([0.08  , 0.1112]),
              array([0.21219512, 0.0695122 ]),
              array([0.09256198, 0.08512397]),
              array([0.04317181, 0.03612335]),
              array([0.09816176, 0.02536765]),
              array([0.08763441, 0.02204301]),
              array([0.06610169, 0.11186441]),
              array([0.04285714, 0.07278912]),
              array([0.09404762, 0.03333333]),
              array([0.04130435, 0.18188406]),
              array([0.07928177, 0.09143646]),
              array([0.11493506, 0.04967532]),
              array([0.06666667, 0.16769231]),
              array([0.05680851, 0.07361702]),
              array([0.11400651, 0.05635179]),
              array([0.05454545, 0.09521531]),
              array([0.07486631, 0.13315508]),
              array([0.06448911, 0.0881072 ]),
              array([0.01448276, 0.06413793]),
              array([0.07291667, 0.07152778]),
              array([0.12291667, 0.09375   ]),
              array([0.08074534, 0.14658385]),
              array([0.09454545, 0.32909091]),
              array([0.14025974, 0.09480519]),
              array([0.08979592, 0.11088435]),
              array([0.09 , 0.054]),
              array([0.103125 , 0.0203125]),
              array([0.05423729, 0.10451977]),
              array([0.05420168, 0.07647059]),
              array([0.04506173, 0.09444444]),
              array([0.06938776, 0.05918367]),
              array([0.06493506, 0.10649351]),
              array([0.05751634, 0.04248366]),
              array([0.05185185, 0.06157407]),
              array([0.09124088, 0.10145985]),
              array([0.08333333, 0.07463768]),
              array([0.10916667, 0.08      ]),
              array([0.05375 , 0.110625]),
              array([0.14935065, 0.05649351]),
              array([0.05768194, 0.07601078]),
              array([0.04839858, 0.11281139]),
              array([0.04791667, 0.06041667]),
              array([0.09066667, 0.104     ]),
              array([0.10334448, 0.10535117]),
              array([0.10422078, 0.0521645 ]),
              array([0.03503185, 0.04713376]),
              array([0.07116564, 0.08527607]),
              array([0.04112903, 0.07096774]),
              array([0.11818182, 0.0719697 ]),
              array([0.12908587, 0.05900277]),
              array([0.04336283, 0.15132743]),
              array([0.16884615, 0.07115385]),
              array([0.08878505, 0.0517757 ]),
              array([0.06502947, 0.11591356]),
              array([0.06363636, 0.02289562]),
              array([0.14766355, 0.07196262]),
              array([0.16165644, 0.09141104]),
              array([0.06142433, 0.09139466]),
              array([0.05146199, 0.12923977]),
              array([0.06564417, 0.05368098]),
              array([0.11086957, 0.05362319]),
              array([0.        , 0.06695652]),
              array([0.12711864, 0.06779661]),
              array([0.09637097, 0.08991935]),
              array([0.07826087, 0.03      ]),
              array([0.07666667, 0.06777778]),
              array([0.08129496, 0.04532374]),
              array([0.1  , 0.075]),
              array([0.        , 0.02978723]),
              array([0.11755424, 0.02879684]),
              array([0.15572917, 0.02864583]),
              array([0.05      , 0.15526316]),
              array([0.11294118, 0.09529412]),
              array([0.12515593, 0.07900208]),
              array([0.03014706, 0.0875    ]),
              array([0.08882979, 0.02659574]),
              array([0.04521739, 0.06695652]),
              array([0.05864979, 0.06793249]),
              array([0.0546798 , 0.05024631]),
              array([0.13630137, 0.07688356]),
              array([0.05956284, 0.12622951]),
              array([0.05775862, 0.09568966]),
              array([0.07468354, 0.05759494]),
              array([0.08796992, 0.07218045]),
              array([0.0605042 , 0.21176471]),
              array([0.06796875, 0.0484375 ]),
              array([0.05429363, 0.07174515]),
              array([0.13225806, 0.06693548]),
              array([0.08275862, 0.06767241]),
              array([0.05699482, 0.09222798]),
              array([0.12773723, 0.07153285]),
              array([0.05487805, 0.15426829]),
              array([0.12848101, 0.13734177]),
              array([0.09626168, 0.03714953]),
              array([0.05943396, 0.1       ]),
              array([0.14333333, 0.06333333]),
              array([0.05925926, 0.09074074]),
              array([0.13613445, 0.07142857]),
              array([0.07878788, 0.13484848]),
              array([0.056875, 0.039375]),
              array([0.08070175, 0.05112782]),
              array([0.07509158, 0.14761905]),
              array([0.025     , 0.12974138]),
              array([0.10352113, 0.02042254]),
              array([0.066875, 0.086875]),
              array([0.06412214, 0.06793893]),
              array([0.06334107, 0.07030162]),
              array([0.07927928, 0.01396396]),
              array([0.07583333, 0.16916667]),
              array([0.07171717, 0.35353535]),
              array([0.06340852, 0.07142857]),
              array([0.08670213, 0.04946809]),
              array([0.08679245, 0.19622642]),
              array([0.11677852, 0.06040268]),
              array([0.09769231, 0.05461538]),
              array([0.02792793, 0.07837838]),
              array([0.04576271, 0.08813559]),
              array([0.0437037 , 0.13037037]),
              array([0.06733333, 0.        ]),
              array([0.08455882, 0.10514706]),
              array([0.1392638 , 0.05521472]),
              array([0.11245421, 0.03113553]),
              array([0.13802281, 0.04081115]),
              array([0.11954023, 0.09655172]),
              array([0.09415205, 0.02690058]),
              array([0.10855615, 0.0368984 ])]})

Visualisation¶

Comment se répartissent les documents du corpus avec la représentation qu'on a choisi

In [12]:
import matplotlib.pyplot as plt
import seaborn as sns

X = np.array([d[0] for d in (*imdb_features["pos"], *imdb_features["neg"])])
Y = np.array([d[1] for d in (*imdb_features["pos"], *imdb_features["neg"])])
H = np.array([*("pos" for _ in imdb_features["pos"]), *("neg" for _ in imdb_features["neg"])])

fig = plt.figure(dpi=200)
sns.scatterplot(x=X, y=Y, hue=H, s=5)
plt.show()

On voit des tendances qui se dégagent, mais clairement ça va être un peu coton

Classifieur linéaire¶

On considère des vecteurs de features de dimension $n$

$$\mathbf{x} = (x₁, …, x_n)$$

Un vecteur de poids de dimension $n$

$$\mathbf{w} = (w₁, …, w_n)$$

et un biais $b$ scalaire (un nombre quoi).

Pour réaliser une classification on considère le nombre $z$ (on parle parfois de logit)

$$z=w₁×x₁ + … + w_n×x_n + b = \sum_iw_ix_i + b$$

Ce qu'on note aussi

$$z = \mathbf{w}⋅\mathbf{x}+b$$

$\mathbf{w}⋅\mathbf{x}$ se lit « w scalaire x », on parle de produit scalaire en français et de inner product en anglais.

(ou pour les mathématicien⋅ne⋅s acharné⋅e⋅s $z = \langle w\ |\ x \rangle + b$)

Quelle que soit la façon dont on le note, on affectera à $\mathbf{x}$ la classe $0$ si $z < 0$ et la classe $1$ sinon.

😴 Exo 😴¶

1. Une fonction affine¶

Écrire une fonction qui prend en entrée un vecteur de features et un vecteur de poids sous forme de tableaux numpy $x$ et $w$ de dimensions (n,) et un biais $b$ sous forme d'un tableau numpy de dimensions (1,) et renvoie $z=\sum_iw_ix_i + b$.

In [13]:
def affine_combination(x, w, b):
    pass # À vous de jouer !

affine_combination(
    np.array([2, 0, 2, 1]),
    np.array([-0.2, 999.1, 0.5, 2]),
    np.array([1]),
)

😴 Correction 1 😴¶

Une version élémentaire avec des boucles

In [14]:
def affine_combination(x, w, b):
    res = np.zeros(1)
    for wi, xi in zip(w, x):
        res += wi*xi
    res += b
    return res

affine_combination(
    np.array([2, 0, 2, 1]),
    np.array([-0.2, 999.1, 0.5, 2]),
    np.array([1]),
)
Out[14]:
array([3.6])

Une version plus courte avec les fonctions natives de numpy

In [15]:
def affine_combination(x, w, b):
    return np.inner(w, x) + b

affine_combination(
    np.array([2, 0, 2, 1]),
    np.array([-0.2, 999.1, 0.5, 2]),
    np.array([1]),
)
Out[15]:
array([3.6])

2. Un classifieur linéaire¶

Écrire un classifieur linéaire qui prend en entrée des vecteurs de features à deux dimensions précédents et utilise les poids respectifs $0.6$ et $-0.4$ et un biais de $-0.01$. Appliquez ce classifieur sur le mini-corpus IMDB qu'on a vectorisé et calculez son exactitude.

In [16]:
def hardcoded_classifier(x):
    return False  # À vous de jouer

hardcoded_classifier(doc_features)
Out[16]:
False

😴 Correction 2 😴¶

On commence par définir le classifieur : on va renvoyer True pour la classe positive et False pour la classe négative.

In [17]:
def hardcoded_classifier(x):
    return affine_combination(x, np.array([0.6, -0.4]), -0.01) > 0.0

hardcoded_classifier(doc_features)
Out[17]:
True

Maintenant on le teste

In [18]:
correct_pos = sum(1 for doc in imdb_features["pos"] if hardcoded_classifier(doc))
print(f"Recall for 'pos': {correct_pos}/{len(imdb_features['pos'])}={correct_pos/len(imdb_features['pos']):.02%}")
correct_neg = sum(1 for doc in imdb_features["neg"] if not hardcoded_classifier(doc))
print(f"Recall for 'neg': {correct_neg}/{len(imdb_features['neg'])}={correct_neg/len(imdb_features['neg']):.02%}")
print(f"Accuracy: {correct_pos+correct_neg}/{len(imdb_features['pos'])+len(imdb_features['neg'])}={(correct_pos+correct_neg)/(len(imdb_features['pos'])+len(imdb_features['neg'])):.02%}")
Recall for 'pos': 269/301=89.37%
Recall for 'neg': 118/301=39.20%
Accuracy: 387/602=64.29%

On en fait une fonction, ça nous sera utile plus tard

In [19]:
def classifier_accuracy(w, b, featurized_corpus):
    correct_pos = sum(1 for doc in imdb_features["pos"] if affine_combination(doc, w, b) > 0.0)
    correct_neg = sum(1 for doc in imdb_features["neg"] if affine_combination(doc, w, b) <= 0.0)
    return (correct_pos+correct_neg)/(len(featurized_corpus['pos'])+len(featurized_corpus['neg']))
classifier_accuracy(np.array([0.6, -0.4]), np.array(-0.01), imdb_features)
Out[19]:
0.6428571428571429

Classifieur linéaire ?¶

Pourquoi linéaire ? Regardez la figure suivante qui colore les points $(x,y)$ du plan en fonction de la valeur de $z$.

In [20]:
import tol_colors as tc

x = np.linspace(0, 1, 1000)
y = np.linspace(0, 1, 1000)
X, Y = np.meshgrid(x, y)
Z = (0.6*X - 0.4*Y) - 0.01

fig = plt.figure(dpi=200)

heatmap = plt.pcolormesh(X, Y, Z, shading="auto", cmap=tc.tol_cmap("sunset"))
plt.colorbar(heatmap)
plt.show()

Ou encore plus clairement, si on représente la classe assignée

In [21]:
import tol_colors as tc

x = np.linspace(0, 1, 1000)
y = np.linspace(0, 1, 1000)
X, Y = np.meshgrid(x, y)
Z = (0.6*X - 0.4*Y) -0.01 > 0.0

fig = plt.figure(dpi=200)

heatmap = plt.pcolormesh(X, Y, Z, shading="auto", cmap=tc.tol_cmap("sunset"))
plt.colorbar(heatmap)
plt.show()

On voit bien que la frontière de classification est une droite, a line. On a donc un linear classifier : un classifieur linéaire (même si en français on dirait qu'il s'agit d'une fonction affine).

Qu'est-ce que ça donne si on superpose avec notre corpus ?

In [22]:
fig = plt.figure(dpi=200)

x = np.linspace(0, 0.4, 1000)
y = np.linspace(0, 0.4, 1000)
X, Y = np.meshgrid(x, y)
Z = (0.6*X - 0.4*Y) -0.01 > 0.0

heatmap = plt.pcolormesh(X, Y, Z, shading="auto", cmap=tc.tol_cmap("sunset"))

X = np.array([d[0] for d in (*imdb_features["pos"], *imdb_features["neg"])])
Y = np.array([d[1] for d in (*imdb_features["pos"], *imdb_features["neg"])])
H = np.array([*(1 for _ in imdb_features["pos"]), *(0 for _ in imdb_features["neg"])])
plt.scatter(x=X, y=Y, c=H, cmap="viridis", s=5)

plt.show()

Pas si surprenant que nos résultats ne soient pas terribles…

La fonction logistique¶

$$σ(z) = \frac{1}{1 + e^{−z}} = \frac{1}{1 + \exp(−z)}$$

Elle permet de normaliser $z$ : $z$ peut être n'importe quel nombre entre $-∞$ et $+∞$, mais on aura toujours $0 < σ(z) < 1$, ce qui permet de l'interpréter facilement comme une vraisemblance. Autrement dit, $σ(z)$ sera proche de $1$ s'il paraît vraisemblable que $x$ appartienne à la classe $1$ et proche de $0$ sinon.

📈 Exo 📈¶

Tracer avec matplotlib la courbe représentative de la fonction logistique.

📈 Correction 📈¶

In [23]:
def logistic(z):
    return 1/(1+np.exp(-z))
In [24]:
%matplotlib inline
import matplotlib.pyplot as plt
x = np.linspace(-10, 10, 5000)
y = logistic(x)
plt.plot(x, y)
plt.xlabel("$x$")
plt.ylabel("$σ(x)$")
plt.title("Courbe représentative de la fonction logistique sur $[-10, 10]$")
plt.show()

Régression logistique¶

Formellement : on suppose qu'il existe une fonction $f$ qui prédit parfaitement les classes, donc telle que pour tout couple exemple/étiquette $(x, y)$ avec $y$ valant $0$ ou $1$, $f(x) = y$. On approcher cette fonction par une fonction $g$ de la forme

$$g(x) = σ(w⋅x+b)$$

Si on choisit les poids $w$ et le biais $b$ tels que $g$ soit la plus proche possible de $f$ sur notre ensemble d'apprentissage, on dit que $g$ est la régression logistique de $f$ sur cet ensemble.

Un classifieur logistique, c'est simplement un classifieur qui pour un exemple $x$ renvoie $0$ si $g(x) < 0.5$ et $1$ sinon. Il a exactement les mêmes capacités de discrimination qu'un classifieur linéaire (sa frontière de décision est la même et il ne sait donc pas prendre de décisions plus complexes), mais on peut interpréter la confiance qu'il a dans sa décision.

Par exemple voici la confiance que notre classifieur codé en dur a en ses décisions

In [25]:
def classifier_confidence(x):
    return logistic(affine_combination(x, np.array([0.6, -0.4]), -0.01))


g_x = classifier_confidence(doc_features)
display(g_x)
display(Markdown(f"Le classifieur est sûr à {g_x:.06%} que ce document est dans la classe $1$."))
display(Markdown(f"Autrement dit, d'après le classifieur, la classe $1$ a {g_x:.06%} de vraisemblance pour ce document"))
0.5135392425438052

Le classifieur est sûr à 51.353924% que ce document est dans la classe $1$.

Autrement dit, d'après le classifieur, la classe $1$ a 51.353924% de vraisemblance pour ce document

Quelle est la vraisemblance de la classe $0$ (review négative) ? Et bien le reste

In [26]:
1.0 - classifier_confidence(doc_features)
Out[26]:
0.48646075745619477

Comme l'exemple en question appartient bien à cette classe, ça signifie que notre classifieur et plutôt bon sur cet exemple. L'est-il sur le reste du corpus ?

In [27]:
pos_confidence = sum(classifier_confidence(doc) for doc in imdb_features["pos"])
print(f"Average confidence for 'pos': {pos_confidence/len(imdb_features['pos']):.02%}")
neg_confidence = sum(1-classifier_confidence(doc) for doc in imdb_features["neg"])
print(f"Average confidence for 'neg': {neg_confidence/len(imdb_features['neg']):.02%}")
print(f"Average confidence for the correct class: {(pos_confidence+neg_confidence)/(len(imdb_features['pos']) + len(imdb_features['neg'])):.02%}")
Average confidence for 'pos': 51.18%
Average confidence for 'neg': 49.80%
Average confidence for the correct class: 50.49%

Autrement dit, pour un exemple pris au hasard dans le corpus, la vraisemblance de sa classe telle que jugée par le classifieur sera de $50.49\%$. Un classifieur parfait obtiendrait $100\%$, un classifieur qui prendrait systématiquement la mauvaise décision $0\%$ et un classifieur aléatoire uniforme $50\%$ (puisque notre corpus a autant d'exemples de chaque classe).

Moralité : nos poids ne sont pas très bien choisis, et notre préoccupation dans la suite va être de chercher comment choisir des poids pour que la confiance moyenne de la classe correcte soit aussi haute que possible.

Fonction de coût¶

On a dit que notre objectif était

Chercher les poids $w$ et le biais $b$ tels que $g$ soit la plus proche possible de $f$ sur notre ensemble d'apprentissage

On formalise « être le plus proche possible » de la section précédente comme minimiser une certaine fonction de coût (loss) $L$ qui mesure l'erreur faite par le classifieur sur un exemple.

$$L(g(x), y) = \text{l'écart entre la classe prédite par $g$ pour $x$ et la classe correcte $y$}$$

Étant donné un ensemble de test $(x₁, y₁), …, (x_n, y_n)$, on estime l'erreur faite par le classifieur logistique $g$ pour chaque exemple $(x_i, y_i)$ comme le coût local $L(g(xᵢ), yᵢ)$ et son erreur sur tout l'ensemble de test par le coût global $\mathcal{L}$ :

$$\mathcal{L} = \sum_i L(g(xᵢ), yᵢ)$$

Plus $\mathcal{L}$ sera bas, meilleur sera notre classifieur.

Dans le cas de la régression logistique, on va s'inspirer de ce qu'on a vu dans la section précédente et utiliser la log-vraisemblance négative (negative log-likelihood) :

On définit la vraisemblance $V$ comme précédemment par $$ V(a, y) = \begin{cases} a & \text{si $y = 1$}\\ 1-a & \text{sinon} \end{cases} $$

Intuitivement, il s'agit de la vraisemblance affectée par le modèle à la classe correcte $y$. Il ne s'agit donc pas d'un coût, mais d'un gain (si sa valeur est haute, c'est que le modèle est bon)

La log-vraisemblance négative $L$ est alors définie par

$$L(a, y) = -\log(V(a, y))$$

Le $\log$ est là pour plusieurs raisons, calculatoires et théoriques1 et le $-$ à s'assurer qu'on a bien un coût (plus la valeur est basse, meilleur le modèle est).

1. Entre autres, comme pour *Naïve Bayes*, parce qu'une somme de $\log$-vraisemblance peut être vue comme le $\log$ de la probabilité d'une conjonction d'événements indépendants. Mais surtout parce qu'il rend la fonction de coût **convexe** par rapport à $w$.

Une interprétation possible : $L(a, y)$, c'est la surprise de $y$ au sens de la théorie de l'information. Autrement dit : si j'estime qu'il y a une probabilité $a$ d'observer la classe $y$, $L(a, y)$ mesure à quel point il serait surprenant d'observer effectivement $y$.

On peut vérifier qu'il s'agit bien d'un coût :

  • C'est un nombre positif
  • Si le classifieur prend une décision correcte avec une confiance parfaite le coût est nul :

    $$ \begin{cases}

      L(1.0, 1) = -\log(1.0) = 0\\
      L(0.0, 0) = -\log(1.0-0.0) = -\log(1.0) = 0
    

    \end{cases} $$

  • Si le classifieur prend une décision erronée avec une confiance parfaite le coût est infini :

    $$ \begin{cases}

      L(0.0, 1) = -\log(0.0) = +\infty\\
      L(1.0, 0) = -\log(1.0-1.0) = \log(0.0) = +\infty
    

    \end{cases} $$

On peut aussi vérifier facilement que $L(a, 1)$ est décroissant par rapport à $a$ et que $L(1-a, 0)$ est croissant par rapport à $a$. Autrement dit, plus le classifieur juge que la classe correcte est vraisemblable plus le coût $L$ est bas.

Enfin, on peut l'écrire $L$ en une ligne : pour un exemple $x$, le coût de l'exemple $(x, y)$ est

$$L(g(x), y) = -\log\left[g(x)×y + (1-g(x))×(1-y)\right]$$

C'est un trick, l'astuce c'est que comme $y$ vaut soit $0$ soit $1$, soit $y=0$, soit $1-y=0$ et donc la somme dans le $\log$ se simplifie dans tous les cas. Rien de transcendant là-dedans.

La formule diffère un peu de celle de Speech and Language Processing, mais les résultats sont les mêmes et celle-ci est mieux pour notre problème !

En fait la leur est la formule générale de l'entropie croisée pour des distributions de proba à support dans $\{0, 1\}$, ce qui est une autre intuition pour cette fonction de coût, mais ici elle nous complique la vie.

Une dernière façon de l'écrire en une ligne :

$$L(g(x), y) = -\log\left[g(x)\mathbb{1}_{y=1} + (1-g(x))\mathbb{1}_{y=0}\right]$$

📉 Exo 📉¶

Écrire une fonction qui prend en entrée

  • Un vecteur de features $x$ de taille $n$
  • Un vecteur de poids $w$ de taille $n$ et un biais $b$ (de taille $1$)
  • Une classe cible $y$ ($0$ ou $1$)

Et renvoie la log-vraisemblance négative du classifieur logistique de poids $(w, b)$ pour l'exemple $(x, y)$.

Servez-vous en pour calculer le coût du classifieur de l'exercise précédent sur le mini-corpus IMDB.

📉 Correction 📉¶

In [28]:
def logistic_negative_log_likelihood(x, w, b, y):
    g_x = logistic(affine_combination(x, w, b))
    if y == 1:
        correct_likelihood = g_x
    else:
        correct_likelihood = 1-g_x
    loss = -np.log(correct_likelihood)
    return loss
In [29]:
def loss_on_imdb(w, b, featurized_corpus):
    loss_on_pos = np.zeros(1)
    for doc_features in featurized_corpus["pos"]:
        loss_on_pos += logistic_negative_log_likelihood(
            doc_features, w, b, 1
        )
    loss_on_neg = np.zeros(1)
    for doc_features in featurized_corpus["neg"]:
        loss_on_neg += logistic_negative_log_likelihood(
            doc_features, w, b, 0
        )
    return loss_on_pos + loss_on_neg

Avec des compréhensions

In [30]:
def loss_on_imdb(w, b, featurized_corpus):
    loss_on_pos = sum(
        logistic_negative_log_likelihood(doc_features, w, b, 1)
        for doc_features in featurized_corpus["pos"]
    )
    loss_on_neg = sum(
        logistic_negative_log_likelihood(doc_features, w, b, 0)
        for doc_features in featurized_corpus["neg"]
    )
    return loss_on_pos + loss_on_neg

En version numériquement stable

In [31]:
import math
def loss_on_imdb(w, b, featurized_corpus):
    loss_on_pos = math.fsum(
        logistic_negative_log_likelihood(doc_features, w, b, 1).astype(float)
        for doc_features in featurized_corpus["pos"]
    )
    loss_on_neg = math.fsum(
        logistic_negative_log_likelihood(doc_features, w, b, 0).astype(float)
        for doc_features in featurized_corpus["neg"]
    )
    return np.array([loss_on_pos + loss_on_neg])
In [32]:
loss_on_imdb(np.array([0.6, -0.4]), -0.01, imdb_features)
Out[32]:
array([411.54449928])

Descente de gradient¶

Principe général¶

L'algorithme de descente de gradient est la clé de voute de l'essentiel des travaux en apprentissage artificiel moderne. Il s'agit d'un algorithme itératif qui étant donné un modèle paramétrisé et une fonction de coût (avec des hypothèses de régularité assez faibles) permet de trouver des valeurs des paramètres pour lesquelles la fonction de coût est minimal.

On ne va pas rentrer dans les détails de l'algorithme de descente de gradient stochastique, mais juste essayer de se donner quelques idées.

L'intuition à avoir est la suivante : si vous êtes dans une vallée et que vous voulez trouver rapidement le point le plus bas, une façon de faire est de chercher la direction vers laquelle la pente descend le plus vite, de faire quelques pas dans cette direction puis de recommencer. On parle aussi pour cette raison d'algorithme de la plus forte pente.

Clairement une condition pour que ça marche peu importe le point de départ, c'est que la vallée n'ait qu'un seul point localement le plus bas. Par exemple ça marche avec une vallée comme celle-ci

In [33]:
%matplotlib inline
import tol_colors as tc
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(20, 20), dpi=200)
ax = plt.axes(projection='3d')

r = np.linspace(0, 8, 100)
p = np.linspace(0, 2*np.pi, 100)
R, P = np.meshgrid(r, p)
Z = R**2 - 1

X, Y = R*np.cos(P), R*np.sin(P)

ax.plot_surface(X, Y, Z, cmap=tc.tol_cmap("sunset"), edgecolor="none", rstride=1, cstride=1)
ax.plot_wireframe(X, Y, Z, color='black')

plt.show()

Mais pas pour celle-là

In [34]:
%matplotlib inline
import tol_colors as tc
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(20, 20), dpi=200)
ax = plt.axes(projection='3d')

r = np.linspace(0, 8, 100)
p = np.linspace(0, 2*np.pi, 100)
R, P = np.meshgrid(r, p)
Z = -np.cos(R)/(1+0.5*R**2)

X, Y = R*np.cos(P), R*np.sin(P)

ax.plot_surface(X, Y, Z, cmap=tc.tol_cmap("sunset"), edgecolor="none", rstride=1, cstride=1)
#ax.plot_wireframe(X, Y, Z, color='black')

plt.show()

OK, mais comment on trouve la plus forte pente en pratique ? En une dimension il suffit de suivre l'opposé du nombre dérivé : https://uclaacm.github.io/gradient-descent-visualiser/#playground

En plus de dimensions, c'est plus compliqué, mais on peut s'en sortir en suivant le gradient qui est une généralisation du nombre dérivé : https://jackmckew.dev/3d-gradient-descent-in-python.html

Ce qui fait marcher la machine c'est que le gradient indique la direction dans laquelle la fonction croît le plus vite. Et que l'opposé du gradient indique la direction dans laquelle la fonction décroît le plus vite.

(localement)

Concrètement si on veut trouver $\theta$ tel que $f(\theta)$ soit minimale pour une certaine fonction $f$ dont le gradient est donné par grad_f ça donne l'algo suivant

def descent(grad_f, theta_0, learning_rate, n_steps):
    theta = theta_0
    for _ in range(n_steps):
        # On trouve la direction de plus grande pente
        steepest_direction = -grad_f(theta)
        # On fait quelques pas dans cette direction
        theta += learning_rate*steepest_direction
    return theta

Les hyperparamètres sont

  • theta_0 est notre point de départ, notre première estimation d'où se trouvera le minimum, que l'algorithme va raffiner. Évidemment si on a déjà une idée de vers où on pourrait le trouver, ça ira plus vite. Si on a aucune idée, on peut le prendre aléatoire.
  • learning_rate ou « taux d'apprentissage » : de combien on se déplace à chaque étape. Si on le prend grand on arrive vite vers la région du minimum, on mettra longtemps pour en trouver une approximation précise. Si on le prend petit, ça sera l'inverse.
  • n_steps est le nombre d'étapes d'optimisations. Dans un problème d'apprentissage, c'est aussi le nombre de fois où on aura parcouru l'ensemble d'apprentissage et on parle souvent d'epoch

Ici on se donne un nombre fixe d'epochs, une autre possibilité serait de s'arrêter quand on ne bouge plus trop, par exemple avec une condition comme

if np.max(grad_f(theta)) < 0.00001:
    break

dans la boucle et éventuellement avec une boucle infinie while True.

Point notation :

  • Le gradient de $f$ est souvent noté $\nabla f$ ou $\operatorname{grad}f$, voire $\vec\nabla f$ ou $\overrightarrow{\operatorname{grad}} f$ (pour dire que c'est un vecteur)
  • Si $θ=(θ_1, …, θ_n)$, autrement dit si $f$ est une fonction de $n$ variables, on note $\operatorname{grad}f = \left(\frac{∂f(θ)}{∂θ_1}, …, \frac{∂f(θ)}{∂θ_n}\right)$. Autrement dit $\frac{∂f(θ)}{∂θ_i}$, la dérivée partielle de $f(θ)$ par rapport à $θ_i$ est la $i$-ème coordonnées du gradient de $f$.
  • Le taux d'apprentissage est souvent noté $α$ ou $η$

Descente de gradient stochastique¶

Rappelez-vous, on a dit que notre fonction de coût, c'était

$$\mathcal{L} = \sum_i L(g(xᵢ), yᵢ)$$

et on cherche la valeur du paramètre $θ = (w_1, …, w_n, b)$ tel que $\mathcal{L}$ soit le plus petit possible.

On peut utilise la propriété d'additivité du gradient : pour deux fonctions $f$ et $g$, on a

$$\operatorname{grad}(f+g) = \operatorname{grad}f + \operatorname{grad}g$$

Donc ici

$$\operatorname{grad}\mathcal{L} = \sum_i \operatorname{grad}L(g(xᵢ), yᵢ)$$

Si on dispose d'une fonction grad_L qui, étant donnés $g(x_i)$ et $y_i$, renvoie $\operatorname{grad}L(g(x_i), y_i)$, l'algorithme de descente du gradient devient alors

def descent(train_set, theta_0, learning_rate, n_steps):
    theta = theta_0
    for _ in range(n_steps):
        w = theta[:-1]
        b = theta[-1]
        partial_grads = []
        for (x, y) in train_set:
            # On calcule g(x)
            g_x = logistic(np.inner(w,x)+b)
            # On calcule le gradient de L(g(x), y))
            partial_grads.append(grad_L(g_x, y))
        # On trouve la direction de plus grande pente
        steepest_direction = -np.sum(partial_grads)
        # On fait quelques pas dans cette direction
        theta += learning_rate*steepest_direction

    return theta

Pour chaque étape, on doit calculer tous les $g(x_i)$ et $\operatorname{grad}L(g(x_i), y_i)$. C'est très couteux, il doit y avoir moyen de faire mieux.

Si les $L(g(xᵢ), yᵢ)$ étaient indépendants, ce serait plus simple : on pourrait les optimiser séparément.

Ce n'est évidemment pas le cas : si on change $g$ pour que $g(x_0)$ soit plus proche de $y_0$, ça changera aussi la valeur de $g(x_1)$.

Mais on va faire comme si

C'est une approximation sauvage, mais après tout on commence à avoir l'habitude. On va donc suivre l'algo suivant

def descent(train_set, theta_0, learning_rate, n_steps):
    theta = theta_0
    for _ in range(n_steps):
        for (x, y) in train_set:
            w = theta[:-1]
            b = theta[-1]
            # On calcule g(x)
            g_x = logistic(np.inner(w,x)+b)
            # On trouve la direction de plus grande pente
            steepest_direction = -grad_L(g_x, y)
            # On fait quelques pas dans cette direction
            theta += learning_rate*steepest_direction

    return theta

Faites bien attention à la différence : au lieu d'attendre d'avoir calculé tous les $\operatorname{grad}L(g(x_i), y_i)$ avant de modifier $θ$, on va le modifier à chaque fois.

  • Avantage : on modifie beaucoup plus souvent le paramètre, si tout se passe bien, on devrait arriver à une bonne approximation très vite.
  • Inconvénient : il se pourrait qu'en essayant de faire baisser $L(g(x_0), y_0)$, on fasse augmenter $L(g(x_1), y_1)$.

Notre espoir ici c'est que cette situation n'arrivera pas, et qu'on bon paramètre pour un certain couple $(x, y)$ c'est un bon paramètres pour $tous$ les couples (exemple, classe).

Ce nouvel algorithme s'appelle l'algorithme de descente de gradient stochastique, et il est crucial pour nous, parce qu'on ne pourra en pratique quasiment jamais faire de descente de gradient globale.

Il ne nous reste plus qu'à savoir comment on calcule grad_L. On ne fera pas la preuve, mais on a

$$\frac{∂L(g(x), y)}{∂w_i} = (g(x)-y)x_i$$

et

$$\frac{∂L(g(x), y)}{∂b} = g(x)-y$$

Autrement dit on mettra à jour $w$ en calculant

$$w ← w -η×\operatorname{d}_wL(g(x), y) = w - η×(g(x)-y)x$$

$\operatorname{d}_wL(g(x), y) = \left(\frac{∂L(g(x), y)}{∂w_1}, …, \frac{∂L(g(x), y)}{∂w_n}\right)$ est la *différentielle partielle* de $L(g(x), y)$ par rapport à $w$.

Et $b$ en calculant

$$b ← b -η×\frac{∂L(g(x), y)}{∂b} = b - η×(g(x)-y)$$

🧐 Exo 🧐¶

1. Calculer le gradient¶

Reprendre la fonction qui calcule la fonction de coût, et la transformer pour qu'elle renvoie le gradient par rapport à $w$ et la dérivée partielle par rapport à $b$ en $(x, y)$.

In [35]:
def grad_L(x, w, b, y):
    grad = np.zeros(w.size+b.size)  # À vous !
    return grad

grad_L(np.array([5, 10]), np.array([0.6, -0.4]), np.array([-0.01]), 0)
Out[35]:
array([0., 0., 0.])

🧐 Correction 1 🧐¶

In [36]:
def grad_L(x, w, b, y):
    g_x = logistic(np.inner(w, x) + b)
    grad_w = (g_x - y)*x
    grad_b = g_x - y
    return np.append(grad_w, grad_b)
grad_L(np.array([5, 10]), np.array([0.6, -0.4]), np.array([-0.01]), 0)
Out[36]:
array([1.33489925, 2.66979851, 0.26697985])

2. Descendre le gradient¶

S'en servir pour apprendre les poids à donner aux features précédentes à l'aide du mini-corpus IMDB en utilisant l'algorithme de descente de gradient stochastique.

In [37]:
def descent(featurized_corpus, theta_0, learning_rate, n_steps):
    theta = theta_0
    for _ in range(n_steps):
        pass  # À vous !
    return 
descent(imdb_features, np.array([0.6, -0.4, 0.0]), 0.001, 100)

🧐 Correction 2 🧐¶

Version minimale

In [38]:
import random

def descent(featurized_corpus, theta_0, learning_rate, n_steps):
    train_set = [
        *((doc, 1) for doc in featurized_corpus["pos"]),
        *((doc, 0) for doc in featurized_corpus["neg"])
    ]
    theta = theta_0
    w = theta[:-1]
    b = theta[-1]
    
    for i in range(n_steps):
        # On mélange le corpus pour s'assurer de ne pas avoir d'abord tous
        # les positifs puis tous les négatifs
        random.shuffle(train_set)
        for j, (x, y) in enumerate(train_set):
            grad = grad_L(x, w, b, y)
            steepest_direction = -grad
            theta += learning_rate*steepest_direction
            w = theta[:-1]
            b = theta[-1]
    return (theta[:-1], theta[-1])

Avec du feedback pour voir ce qui se passe

In [39]:
def descent_with_logging(featurized_corpus, theta_0, learning_rate, n_steps):
    train_set = [
        *((doc, 1) for doc in featurized_corpus["pos"]),
        *((doc, 0) for doc in featurized_corpus["neg"])
    ]
    theta = theta_0
    theta_history = [theta_0.tolist()]
    w = theta[:-1]
    b = theta[-1]
    print("Epoch\tLoss\tAccuracy\tw\tb")
    print(f"Initial\t{loss_on_imdb(w, b, featurized_corpus).item()}\t{classifier_accuracy(w, b, featurized_corpus)}\t{w}\t{b}")
    
    for i in range(n_steps):
        # On mélange le corpus pour s'assurer de ne pas avoir d'abord tous
        # les positifs puis tous les négatifs
        random.shuffle(train_set)
        for j, (x, y) in enumerate(train_set):
            grad = grad_L(x, w, b, y)
            steepest_direction = -grad
            # Purement pour l'affichage
            loss = logistic_negative_log_likelihood(x, w, b, y)
            #print(f"step {i*len(train_set)+j} doc={x}\tw={w}\tb={b}\tloss={loss}\tgrad={grad}")
            theta += learning_rate*steepest_direction
            w = theta[:-1]
            b = theta[-1]
        theta_history.append(theta.tolist())
        epoch_train_loss = loss_on_imdb(w, b, featurized_corpus).item()
        epoch_train_accuracy = classifier_accuracy(w, b, imdb_features)
        print(f"{i}\t{epoch_train_loss}\t{epoch_train_accuracy}\t{w}\t{b}")
    return (theta[:-1], theta[-1]), theta_history

theta, theta_history = descent_with_logging(imdb_features, np.array([0.6, -0.4, -0.01]), 0.1, 100)
Epoch	Loss	Accuracy	w	b
Initial	411.5444992792534	0.6428571428571429	[ 0.6 -0.4]	-0.01
0	405.8386336123548	0.632890365448505	[ 1.21310828 -0.87566575]	-0.008855653028272105
1	400.72297144894566	0.6445182724252492	[ 1.7890891  -1.33248739]	-0.018491113051853164
2	403.45381560771926	0.5382059800664452	[ 2.29369692 -1.77953224]	-0.4545575321164322
3	395.6503914668431	0.5913621262458472	[ 2.82185236 -2.18052694]	-0.40458225163692835
4	387.1542096632066	0.6976744186046512	[ 3.33605168 -2.54590889]	-0.18933898754313291
5	383.5335167510666	0.6943521594684385	[ 3.81157641 -2.91899776]	-0.20178345072824427
6	383.5990265254297	0.6661129568106312	[ 4.23717375 -3.29101946]	-0.4479254419719547
7	381.68724776321716	0.6661129568106312	[ 4.66271698 -3.64014333]	-0.503672693305077
8	378.1126755414382	0.6561461794019934	[ 5.11628545 -3.93652054]	-0.06144114558073053
9	371.94336031631906	0.6943521594684385	[ 5.49276656 -4.27123132]	-0.2878481598708334
10	377.0888714324588	0.659468438538206	[ 5.83328624 -4.59486394]	-0.639959335664362
11	367.4894434204998	0.6943521594684385	[ 6.22800556 -4.86883491]	-0.35339269780994603
12	375.567034043822	0.6312292358803987	[ 6.61074189 -5.12690271]	0.011239346392446892
13	363.9553688029009	0.6910299003322259	[ 6.90176766 -5.42959361]	-0.43599557392810667
14	363.57611394304035	0.6827242524916943	[ 7.24011194 -5.6788869 ]	-0.2470192733948322
15	361.4648011716828	0.6960132890365448	[ 7.51408108 -5.94702192]	-0.5227629390451304
16	367.2646758585722	0.686046511627907	[ 7.77590512 -6.21267235]	-0.7587499017142372
17	357.80180915179835	0.6910299003322259	[ 8.08370856 -6.42585416]	-0.4283648067618686
18	361.17220368344556	0.6893687707641196	[ 8.32181749 -6.67016784]	-0.7009901151729937
19	356.8479050113456	0.6777408637873754	[ 8.60820508 -6.86263897]	-0.32122863075552194
20	355.76299026749166	0.6760797342192691	[ 8.86433692 -7.07933871]	-0.33396753399908485
21	353.88809611798035	0.6943521594684385	[ 9.07598057 -7.30462783]	-0.5616474470516962
22	361.52229649789604	0.6877076411960132	[ 9.27869262 -7.52543302]	-0.8521437331195502
23	351.68951583440196	0.6943521594684385	[ 9.53632493 -7.69592462]	-0.4940223296265582
24	355.58928625141937	0.6877076411960132	[ 9.77740493 -7.86762927]	-0.25379561139116746
25	354.16382940397307	0.686046511627907	[ 9.9787748  -8.05158201]	-0.2825482033087819
26	350.4847844924545	0.6810631229235881	[10.16341068 -8.23846746]	-0.4104350193690884
27	348.88467418195347	0.6910299003322259	[10.3427597  -8.41534238]	-0.5113536858984042
28	350.5297651766308	0.7026578073089701	[10.50621087 -8.59318403]	-0.7324578468725698
29	347.70952145004816	0.6993355481727574	[10.69344899 -8.74583103]	-0.5797088738676763
30	348.8468488245314	0.6794019933554817	[10.88093471 -8.89693503]	-0.4047648251935684
31	347.30973795466355	0.6794019933554817	[11.03664191 -9.0473922 ]	-0.47242954088770106
32	346.31800017420244	0.6943521594684385	[11.18410619 -9.1975602 ]	-0.5359324375685327
33	346.7810400202328	0.6843853820598007	[11.34476978 -9.33892568]	-0.46049557052123324
34	345.3839523976593	0.6960132890365448	[11.47544077 -9.48703414]	-0.6066153729437318
35	360.22999658942035	0.6578073089700996	[11.66836333 -9.60065723]	-0.1160201421212139
36	344.63376882564467	0.6960132890365448	[11.752985   -9.75985053]	-0.6164191284726369
37	345.3013309369365	0.6893687707641196	[11.87292074 -9.89305567]	-0.7292493346860729
38	348.5766319656864	0.7059800664451827	[ 11.98817793 -10.02887861]	-0.8746998067973188
39	344.0991588222455	0.6926910299003323	[ 12.14127279 -10.13430157]	-0.6993350219074241
40	343.4040167644193	0.6893687707641196	[ 12.27190047 -10.24983916]	-0.5964370096554679
41	344.8278642990712	0.6777408637873754	[ 12.39854015 -10.35401721]	-0.46400950274346203
42	344.1123670986865	0.6910299003322259	[ 12.48067758 -10.48855083]	-0.7688925147609866
43	343.0250052319926	0.6843853820598007	[ 12.61270175 -10.58978785]	-0.5533530407483938
44	342.54882635041736	0.6960132890365448	[ 12.70566046 -10.70599264]	-0.6899508823352177
45	342.3010916444732	0.6943521594684385	[ 12.81730658 -10.80409227]	-0.5949556100325547
46	342.06348946742776	0.6960132890365448	[ 12.91274438 -10.91805167]	-0.6020505153218079
47	342.6845031992872	0.6976744186046512	[ 12.98928784 -11.0290791 ]	-0.766776387513195
48	346.3817299449234	0.7026578073089701	[ 13.07195131 -11.13788626]	-0.9238525610315672
49	341.86792221355677	0.6843853820598007	[ 13.20162868 -11.20950088]	-0.5691615366218123
50	347.3924232317421	0.6877076411960132	[ 13.32276862 -11.29565601]	-0.348788133412944
51	341.50049481464237	0.6943521594684385	[ 13.36699463 -11.41426248]	-0.7435480628984266
52	341.58316406419203	0.6843853820598007	[ 13.47452191 -11.4979468 ]	-0.5609883166983907
53	341.6952468822618	0.6827242524916943	[ 13.55727624 -11.58235769]	-0.5462408142770604
54	347.8105027232011	0.6843853820598007	[ 13.66157319 -11.65806095]	-0.3342696549305616
55	342.49288431289597	0.6777408637873754	[ 13.7274938  -11.75805761]	-0.4950335711656657
56	340.3638203745493	0.6943521594684385	[ 13.78831755 -11.85341542]	-0.674178343895667
57	340.4555132098333	0.6993355481727574	[ 13.85394919 -11.93332803]	-0.7335152813456922
58	340.14646919384086	0.6943521594684385	[ 13.92686459 -12.01171091]	-0.6853596678524726
59	341.566679166022	0.6893687707641196	[ 13.98312106 -12.10402755]	-0.8363799964370819
60	339.97289571169165	0.6993355481727574	[ 14.06806959 -12.17183651]	-0.7089671532973639
61	340.64105665094854	0.6943521594684385	[ 14.12237015 -12.25259635]	-0.7977657678259574
62	345.065921347172	0.6810631229235881	[ 14.22876909 -12.30050641]	-0.3968576085389256
63	342.26279336750486	0.6777408637873754	[ 14.27871149 -12.37278133]	-0.485974384050354
64	339.80875522713984	0.6943521594684385	[ 14.32670254 -12.45548995]	-0.6302564323219602
65	340.3180969371982	0.6943521594684385	[ 14.36529467 -12.53658407]	-0.8057438465860458
66	341.8273876272441	0.6910299003322259	[ 14.41657367 -12.60786468]	-0.8893112431256451
67	343.00897229309254	0.6777408637873754	[ 14.51669149 -12.65065902]	-0.4535129034876174
68	339.6243037486744	0.6960132890365448	[ 14.53827296 -12.73183802]	-0.7697465915115199
69	347.78784859680286	0.6810631229235881	[ 14.63683174 -12.76638435]	-0.32773248028056673
70	345.5065114164821	0.7043189368770764	[ 14.61759927 -12.86755449]	-1.018608639704425
71	339.61099905143067	0.6960132890365448	[ 14.68896389 -12.9149759 ]	-0.7903813816689639
72	339.44650520762826	0.6960132890365448	[ 14.74263687 -12.97301769]	-0.7819519015494137
73	346.80946149115516	0.6843853820598007	[ 14.83570377 -13.00778788]	-0.3485987143461242
74	341.42389311829675	0.6744186046511628	[ 14.87652789 -13.07676639]	-0.5051371958075703
75	339.2908938960876	0.6893687707641196	[ 14.91208949 -13.14543455]	-0.6251888181286102
76	341.7470581964003	0.6760797342192691	[ 14.97182097 -13.19510295]	-0.4904550953852727
77	340.3591489142086	0.6777408637873754	[ 15.01068363 -13.25358547]	-0.549065298753805
78	338.75128011932725	0.6926910299003323	[ 15.02843593 -13.31638823]	-0.7054188058398451
79	339.3480072989758	0.6893687707641196	[ 15.08512644 -13.35929287]	-0.6082960058646443
80	339.5134235884971	0.6943521594684385	[ 15.09764994 -13.42378056]	-0.8268859585396283
81	341.3166703360356	0.6760797342192691	[ 15.17070291 -13.45861326]	-0.5022611402004007
82	341.87520184984663	0.6777408637873754	[ 15.20818921 -13.5076678 ]	-0.4806250071119895
83	353.3376470611691	0.6710963455149501	[ 15.26999106 -13.5428285 ]	-0.22093704217764273
84	338.8390277340502	0.6926910299003323	[ 15.25817146 -13.61295825]	-0.6415338574625433
85	340.76004043930936	0.6760797342192691	[ 15.30121    -13.65492191]	-0.5203398938453683
86	347.7824676021295	0.7043189368770764	[ 15.28152495 -13.74660905]	-1.1028751383454158
87	339.2880905613632	0.6943521594684385	[ 15.33921853 -13.77485032]	-0.8301695855230112
88	338.4201347349331	0.6960132890365448	[ 15.3879761  -13.80992306]	-0.7246458664840142
89	354.65879065809924	0.6694352159468439	[ 15.4720248  -13.81995619]	-0.19786248329221756
90	338.9439175407862	0.6893687707641196	[ 15.46079786 -13.88352975]	-0.6177761084695129
91	338.66555825246235	0.6993355481727574	[ 15.47092856 -13.93341492]	-0.7880095491163048
92	339.1637123517661	0.6843853820598007	[ 15.51767091 -13.96183887]	-0.5981618432678001
93	339.43036765151567	0.6960132890365448	[ 15.52012216 -14.02440826]	-0.8518220707630252
94	344.3072382699968	0.6810631229235881	[ 15.58938014 -14.03386067]	-0.4010928156295215
95	338.29725575078913	0.6960132890365448	[ 15.58680843 -14.08981693]	-0.6905011211026445
96	343.8844489387716	0.6794019933554817	[ 15.64054107 -14.09696118]	-0.41257902695352916
97	339.5121104726793	0.6960132890365448	[ 15.61865104 -14.16345478]	-0.8622857509806164
98	341.61161324368845	0.6760797342192691	[ 15.68558038 -14.18059799]	-0.48059497021842185
99	341.8606205393178	0.6794019933554817	[ 15.70930849 -14.20858304]	-0.4720659469072231

Un peu de visu supplémentaire :

Le trajet fait par $θ$ au cours de l'apprentissage

In [40]:
import numpy as np
import matplotlib.pyplot as plt


fig = plt.figure(figsize=(20, 20), dpi=200)
ax = plt.axes(projection='3d')

x, y, z = np.hsplit(np.array(theta_history), 3)

ax.plot(x.squeeze(), y.squeeze(), z.squeeze(), label="Trajet de $θ$ au cours de l'apprentissage")
ax.legend()

plt.show()
In [41]:
def make_vector_corpus(featurized_corpus):
    vector_corpus = np.stack([*featurized_corpus["pos"], *featurized_corpus["neg"]])
    vector_target = np.concatenate([np.ones(len(featurized_corpus["pos"])), np.zeros(len(featurized_corpus["neg"]))])
    return vector_corpus, vector_target

vector_corpus, vector_target = make_vector_corpus(imdb_features)
In [42]:
w1 = np.linspace(-50, 100, 200)
w2 = np.linspace(-100, 50, 200)
W1, W2 = np.meshgrid(w1, w2)
W = np.stack((W1, W2), axis=-1)
# Un peu de magie pour accélérer le calcul
confidence = logistic(
    np.einsum("ijn,kn->ijk", W, vector_corpus)
)
broadcastable_target = vector_target[np.newaxis, np.newaxis, :]
loss = -np.log(confidence * broadcastable_target + (1-confidence)*(1-broadcastable_target)).sum(axis=-1)
fig = plt.figure(figsize=(20, 20), dpi=200)
ax = plt.axes(projection='3d')
ax.set_xlim(-50, 100)
ax.set_ylim(-100, 50)
ax.set_zlim(0, 3000)

surf = ax.plot_surface(W1, W2, loss, cmap=tc.tol_cmap("sunset"), edgecolor="none", rstride=1, cstride=1, alpha=0.8)
fig.colorbar(surf, shrink=0.5, aspect=5)
ax.plot_wireframe(W1, W2, loss, color='black')

heatmap = ax.contourf(W1, W2, loss, offset=-30, cmap=tc.tol_cmap("sunset"))

plt.title("Paysage de la fonction de coût en fonction des valeurs de $w$ pour $b=0$")

plt.show()

Régression multinomiale¶

Un dernier point : on a vu dans tout ceci comment utiliser la régression logistique pour un problème de classification à deux classes. Comment on l'étend à $n$ classes ?

Réfléchissons déjà à quoi ressemblerait la sortie d'un tel classifieur :

Pour un problème à deux classes, le classifieur $g$ nous donne pour chaque exemple $x$ une estimation $g(x)$ de la vraisemblance de la classe $1$, et on a vu que la vraisemblance de la classe $0$ était nécessairement $1-g(x)$ pour que la somme des vraisemblances fasse 1.

On peut le présenter autrement : considérons le classifieur $f$ tel que pour tout exemple $x$

$$f(x) = (1-g(x), g(x))$$

$f$ nous donne un vecteur à deux coordonnées, $f_0(x)$ et $f_1(x)$, qui sont respectivement les vraisemblances des classes $0$ et $1$.

Pour un problème à $n$ classes, on va vouloir une vraisemblance par classe, on va donc procéder de la façon suivante :

On considère des poids $(w_1, b_1), …, (w_n, b_n)$. Ils définissent un classifieur linéaire.

En effet, si on considère les $z_i$ définis pour tout exemple $x$ par

$$ \begin{cases} z_1 = w_1⋅x + b_1\\ \vdots\\ z_n = w_n⋅x + b_1 \end{cases} $$

On peut choisir la classe $y$ à affecter à $x$ en prenant $y=\operatorname{argmax}\limits_i z_i$

Il reste à normaliser pour avoir des vraisemblances. Pour ça on utilise une fonction très importante : la fonction $\operatorname{softmax}$, définie ainsi :

$$\operatorname{softmax}(z_1, …, z_n) = \left(\frac{e^{z_1}}{\sum_i e^{z_i}}, …, \frac{e^{z_n}}{\sum_i e^{z_i}}\right)$$

Contrairement à la fonction logistique qui prenait un nombre en entrée et renvoyait un nombre, $\operatorname{softmax}$ prend en entrée un vecteur non-normalisé et renvoie un vecteur normalisé.

On définit enfin le classifieur logistique multinomial $f$ de la façon suivante : pour tout exemple $x$, on a

$$f(x) = \operatorname{softmax}(w_1⋅x+b_1, …, w_n⋅x+b_n) = \left(\frac{e^{w_1⋅x+b_1}}{\sum_i e^{w_i⋅x+b_i}}, …, \frac{e^{w_n⋅x+b_n}}{\sum_i e^{w_i⋅x+b_i}}\right)$$

et on choisit pour $x$ la classe

$$y = \operatorname{argmax}\limits_i f_i(x) = \operatorname{argmax}\limits_i \frac{e^{w_i⋅x+b_i}}{\sum_j e^{w_j⋅x+b_j}}$$

Comme la fonction exponentielle est croissante, ce sera la même classe que le classifieur linéaire précédent. Comme pour le cas à deux classe, la différence se fera lors de l'apprentissage. Je vous laisse aller en lire les détails dans Speech and Language Processing, mais l'idée est la même : on utilise la log-vraisemblance négative de la classe correcte comme fonction de coût, et on optimise les paramètres avec l'algo de descente de gradient stochastique.

Un dernier détail ?

Qu'est-ce qui se passe si on prend ce qu'on vient de voir pour $n=2$ ? Est-ce qu'on retombe sur le cas à deux classe vu précédemment ?

Oui, regarde : dans ce cas

$$ \begin{align} f_1(x) &= \frac{e^{w_1⋅x+b_1}}{e^{w_0⋅x+b_0}+e^{w_1⋅x+b_1}}\\ &= \frac{1}{ \frac{e^{w_0⋅x+b_0}}{e^{w_1⋅x+b_1}} + 1 }\\ &= \frac{1}{e^{(w_0⋅x+b_0)-(w_1⋅x+b_1)} + 1}\\ &= \frac{1}{1 + e^{(w_0-w_1)⋅x+(b_0-b_1)}}\\ &= σ((w_0-w_1)⋅x+(b_0-b_1)) \end{align} $$

Autrement dit, appliquer ce qu'on vient de voir pour le cas multinomial, si $n=2$, c'est comme appliquer ce qu'on a vu pour deux classes, avec $w=w_0-w_1$ et $b=b_0-b_1$.

La suite¶

Vous êtes arrivé⋅e⋅s au bout de ce cours et vous devriez avoir quelques idées de plusieurs concepts importants :

  • Le concept de classifieur linéaire
  • Le concept de fonction de coût
  • L'algorithme de descente de gradient stochastique
  • La fonction softmax

On reparlera de tout ça en temps utile. Pour la suite de vos aventures au pays des classifieurs logistiques, je vous recommande plutôt d'utiliser leur implémentation dans scikit-learn. Maintenant que vous savez comment ça marche, vous pouvez le faire la tête haute. Bravo !

Vous avez aussi découvert les premiers réseaux de neurones de ce cours et ce n'est pas rien !